From 52b989c6c8f49428f02661797ec6bdcd93700595 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 23 Jan 2026 09:48:39 +0800 Subject: [PATCH] don't place consts early + fixes from anthropic challenge (#14286) * don't place consts early * add anthropic challenge * with ref * do we still have to devectorize bools? * tests pass * just WHERE * fine, revert that * fine, revert * only index * z3 validator doesn't support vectorized * Revert "z3 validator doesn't support vectorized" This reverts commit 1b7930ecb3f8eb6aab0d0803d06380a5146498f9. * z3 not for vec * no spec * VLIWRenderer * loop unrolling * better comments * cleanups * skip cast * renderer * cleanups * prints * no hack * hacks * bump to 11 * reg warning * lil clean * cleaner renderer --- .github/workflows/benchmark.yml | 2 +- examples/anthropic_challenge.py | 196 ++++++++++++++++++++++++++ tinygrad/codegen/__init__.py | 4 +- tinygrad/codegen/late/devectorizer.py | 14 +- tinygrad/codegen/late/linearizer.py | 1 - tinygrad/uop/spec.py | 3 +- tinygrad/uop/symbolic.py | 4 +- tinygrad/uop/validate.py | 4 +- 8 files changed, 216 insertions(+), 12 deletions(-) create mode 100644 examples/anthropic_challenge.py diff --git a/.github/workflows/benchmark.yml b/.github/workflows/benchmark.yml index da1a281570..d2329a822b 100644 --- a/.github/workflows/benchmark.yml +++ b/.github/workflows/benchmark.yml @@ -561,7 +561,7 @@ jobs: - name: openpilot compile3 0.10.1 driving_policy run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_policy.onnx - name: openpilot compile3 0.10.1 dmonitoring - run: BENCHMARK_LOG=openpilot_0_10_1_dmonitoring PYTHONPATH="." 
def myhash(a: Tensor) -> Tensor:
  """Mix `a` through six fixed add/xor/shift stages and return the mixed value.

  Only `+`, `^`, `<<` and `>>` are applied to `a`, so the result has whatever type those
  operators produce for the argument.
  """
  add = lambda x, y: x + y
  xor = lambda x, y: x ^ y
  shl = lambda x, n: x << n
  shr = lambda x, n: x >> n
  # every stage computes: a = combine(mix(a, const), shift(a, amount))
  stages = (
    (add, 0x7ED55D16, add, shl, 12),
    (xor, 0xC761C23C, xor, shr, 19),
    (add, 0x165667B1, add, shl, 5),
    (add, 0xD3A2646C, xor, shl, 9),
    (add, 0xFD7046C5, add, shl, 3),
    (xor, 0xB55A4F09, xor, shr, 16),
  )
  for mix, const, combine, shift, amount in stages:
    a = combine(mix(a, const), shift(a, amount))
  return a
def loop_unrolling(sink:UOp):
  """Fully unroll the first RANGE reachable from `sink`.

  Returns None when the graph contains no RANGE. Otherwise returns a new SINK (same `arg`)
  whose sources are the sources of `sink` with the RANGE replaced by each constant
  0..vmax in turn, one copy per iteration, in iteration order.

  Only the first RANGE found by `toposort()` is handled per call.
  NOTE(review): presumably the pattern matcher calls this again on the returned SINK if
  further RANGEs remain -- confirm against the rewrite engine.
  """
  ranges = [x for x in sink.toposort() if x.op is Ops.RANGE]
  if not ranges: return None
  rng = ranges[0]
  trip_count = rng.vmax+1
  print(f"unrolling loop with size {trip_count}")
  # keep *every* source of each unrolled copy. taking only src[0] would silently drop the
  # remaining sources of a SINK that has more than one (e.g. a SINK produced by a previous unroll)
  unrolled_srcs = [s for i in range(trip_count) for s in sink.substitute({rng:rng.const_like(i)}).src]
  return UOp.sink(*unrolled_srcs, arg=sink.arg)
def render(self, uops:list[UOp]):
  """Translate a linear list of UOps into instruction dicts and return their `repr`.

  Each emitted instruction is a dict with a single key naming the slot ("flow", "load",
  "store", "alu" or "valu") mapped to a one-element list holding the op tuple.
  Registers are handed out sequentially: every UOp except STORE, SINK and GEP gets
  `dtype.count` consecutive registers. Raises NotImplementedError for an op that has no
  explicit case and is not in `code_for_op`.
  """

  # TODO: this is a minimal renderer. for low cycle count, make it good
  # to get speed, you need to add VLIW packing
  # to get under 1536 regs, you need to add a register allocator
  # we left the fun parts to you

  print(f"rendering with {len(uops)} uops")
  next_free = 0
  program = []
  slot: dict[UOp, int] = {}
  no_own_register = {Ops.STORE, Ops.SINK, Ops.GEP}
  for u in uops:
    width = u.dtype.count
    assert width in (1,8), "dtype count must be 1 or 8"

    # bump allocator: registers are never reused
    if u.op not in no_own_register:
      slot[u] = next_free
      next_free += width

    # one UOp -> zero or more instructions
    if u.op is Ops.SINK:
      program.append({"flow": [("halt",)]})
    elif u.op is Ops.CONST:
      program.append({"load": [("const", slot[u], u.arg)]})
    elif u.op is Ops.GEP:
      # no instruction: a GEP aliases one register inside its source's register range
      slot[u] = slot[u.src[0]] + u.arg[0]
    elif u.op is Ops.VECTORIZE:
      first = u.src[0]
      if all(s == first for s in u.src):
        # every lane comes from the same source, so broadcast it
        program.append({"valu": [("vbroadcast", slot[u], slot[first])]})
      else:
        # copy each lane into place, skipping lanes that already sit in the right register
        for lane, s in enumerate(u.src):
          if slot[s] != slot[u]+lane:
            program.append({"flow": [("add_imm", slot[u]+lane, slot[s], 0)]})
    elif u.op is Ops.LOAD:
      mnemonic = "vload" if width > 1 else "load"
      program.append({"load": [(mnemonic, slot[u], slot[u.src[0]])]})
    elif u.op is Ops.STORE:
      mnemonic = "vstore" if u.src[1].dtype.count > 1 else "store"
      program.append({"store": [(mnemonic, slot[u.src[0]], slot[u.src[1]])]})
    elif u.op is Ops.MULACC:
      assert width == 8
      program.append({"valu": [("multiply_add", slot[u], slot[u.src[0]], slot[u.src[1]], slot[u.src[2]])]})
    elif u.op is Ops.WHERE:
      assert width == 8
      program.append({"flow": [("vselect", slot[u], slot[u.src[0]], slot[u.src[1]], slot[u.src[2]])]})
    elif u.op in self.code_for_op:
      # plain binary op: vector unit for wide dtypes, scalar unit otherwise
      unit = "valu" if width > 1 else "alu"
      program.append({unit: [(self.code_for_op[u.op], slot[u], slot[u.src[0]], slot[u.src[1]])]})
    else:
      raise NotImplementedError(f"unhandled op {u.op}")
  return repr(program)
def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp):
  """Build a pointer INDEX on `buf` broadcast to `cast.dtype.count * bcast.dtype.vcount` lanes.

  The new index is `idx` gathered with a repeating `0..precnt-1` pattern, scaled by
  `cast.dtype.count`, plus a constant per-lane offset vector.
  """
  cnt = cast.dtype.count
  precnt = bcast.dtype.vcount
  # TODO: it is not understood *why* the GEP case needs vcount while the other case needs count;
  # this split was found empirically by making the tests pass
  if bcast.op is Ops.GEP:
    lanes, offsets = cast.dtype.vcount, bcast.arg
  else:
    lanes, offsets = cnt, (0,)*precnt
  gep_arg = tuple(flatten([range(precnt) for _ in range(lanes)]))
  sum_arg = tuple(flatten([[i+off for off in offsets] for i in range(lanes)]))
  lane_offsets = UOp.const(dtypes.index.vec(len(sum_arg)), sum_arg)
  new_idx = idx.gep(gep_arg)*cnt + lane_offsets
  return buf.broadcast(cnt*precnt).index(new_idx, ptr=True)
Ops.BITCAST or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True + if x.op in {Ops.BITCAST, Ops.VECTORIZE, Ops.GEP} or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True # if all is good and IGNORE_OOB=0, validate with z3 from tinygrad.uop.validate import validate_index_with_z3 diff --git a/tinygrad/uop/symbolic.py b/tinygrad/uop/symbolic.py index b9b657ebb8..fd100c38d6 100644 --- a/tinygrad/uop/symbolic.py +++ b/tinygrad/uop/symbolic.py @@ -163,8 +163,8 @@ gep_pushing = PatternMatcher([ (UPat(Ops.GEP, src=(UPat(dtype=dtypes.void, name="x"),)), lambda x: x), # GEP in order is removed (UPat(Ops.GEP, name="g"), lambda g: g.src[0] if not isinstance(g.dtype, PtrDType) and g.arg == tuple(range(g.src[0].dtype.count)) else None), - # push all GEPs through ALUs (fix arange stuff) - (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu').f(Ops.GEP, name='gep'), + # push all GEPs through ALUs for index (TODO: remove this) + (UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu').f(Ops.GEP, dtype=dtypes.index, name='gep'), lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \ if not isinstance(gep.dtype, PtrDType) and not isinstance(alu.dtype, PtrDType) else None), # CAT can't be rendered. 
it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later) diff --git a/tinygrad/uop/validate.py b/tinygrad/uop/validate.py index 4084b875b9..2590297c6c 100644 --- a/tinygrad/uop/validate.py +++ b/tinygrad/uop/validate.py @@ -48,7 +48,9 @@ def uops_to_z3(solver, *uops: UOp) -> list[z3.ExprRef]: lst = list(UOp.sink(*uops).toposort(gate=lambda x: x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.index) or x.op is Ops.SINK))[:-1] z3map: dict[UOp, z3.ExprRef] = {} for i,u in enumerate(lst): - new_u, constraint = cast(tuple[z3.ArithRef, z3.BoolRef|None], z3_renderer.rewrite(u, ctx=(solver, z3map))) + z3_rewritten = z3_renderer.rewrite(u, ctx=(solver, z3map)) + if z3_rewritten is None: raise NotImplementedError(f"{u.op} is not supported by z3") + new_u, constraint = cast(tuple[z3.ArithRef, z3.BoolRef|None], z3_rewritten) if constraint is not None: solver.add(constraint) z3map[u] = new_u assert all(u in z3map for u in uops), "UOp failed to rewrite to z3!"