mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
don't place consts early + fixes from anthropic challenge (#14286)
* don't place consts early
* add anthropic challenge
* with ref
* do we still have to devectorize bools?
* tests pass
* just WHERE
* fine, revert that
* fine, revert
* only index
* z3 validator doesn't support vectorized
* Revert "z3 validator doesn't support vectorized"
This reverts commit 1b7930ecb3.
* z3 not for vec
* no spec
* VLIWRenderer
* loop unrolling
* better comments
* cleanups
* skip cast
* renderer
* cleanups
* prints
* no hack
* hacks
* bump to 11
* reg warning
* lil clean
* cleaner renderer
This commit is contained in:
2
.github/workflows/benchmark.yml
vendored
2
.github/workflows/benchmark.yml
vendored
@@ -561,7 +561,7 @@ jobs:
|
||||
- name: openpilot compile3 0.10.1 driving_policy
|
||||
run: BENCHMARK_LOG=openpilot_0_10_1_policy PYTHONPATH="." ASSERT_MIN_STEP_TIME=3 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/driving_policy.onnx
|
||||
- name: openpilot compile3 0.10.1 dmonitoring
|
||||
run: BENCHMARK_LOG=openpilot_0_10_1_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=10 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/dmonitoring_model.onnx
|
||||
run: BENCHMARK_LOG=openpilot_0_10_1_dmonitoring PYTHONPATH="." ASSERT_MIN_STEP_TIME=11 DEV=QCOM FLOAT16=1 IMAGE=2 NOLOCALS=1 taskset -c 4-7 python3 examples/openpilot/compile3.py https://github.com/commaai/openpilot/raw/720392c9a5b986981fdbed1bb8c47a6c5573a50e/selfdrive/modeld/models/dmonitoring_model.onnx
|
||||
- name: benchmark MobileNetV2 on DSP
|
||||
run: |
|
||||
# generate quantized weights
|
||||
|
||||
196
examples/anthropic_challenge.py
Normal file
196
examples/anthropic_challenge.py
Normal file
@@ -0,0 +1,196 @@
|
||||
from tinygrad import Tensor, dtypes, Context, getenv, UOp, fetch
|
||||
from tinygrad.uop.ops import Ops, PatternMatcher, UPat
|
||||
from tinygrad.uop.symbolic import symbolic
|
||||
from tinygrad.codegen import Renderer
|
||||
from tinygrad.codegen.opt import Opt, OptOps
|
||||
|
||||
# ************************* implementation of the problem ************************
|
||||
|
||||
def myhash(a: Tensor) -> Tensor:
  """Bob Jenkins' six-step 32-bit integer mix, expressed as tensor ops.

  Each step folds in a magic constant and a shifted copy of the value;
  the steps are applied in order, exactly as in the scalar reference.
  """
  mix_steps = [
    lambda x: (x + 0x7ED55D16) + (x << 12),
    lambda x: (x ^ 0xC761C23C) ^ (x >> 19),
    lambda x: (x + 0x165667B1) + (x << 5),
    lambda x: (x + 0xD3A2646C) ^ (x << 9),
    lambda x: (x + 0xFD7046C5) + (x << 3),
    lambda x: (x ^ 0xB55A4F09) ^ (x >> 16),
  ]
  for step in mix_steps:
    a = step(a)
  return a
|
||||
|
||||
def select_with_where_tree(values: Tensor, relative_idx: Tensor) -> Tensor:
  """Select values[relative_idx] per element using a balanced tree of WHEREs.

  Avoids a gather for small candidate sets: recursively splits `values` in
  half and picks between the two recursive results with a comparison.
  """
  count = values.shape[0]
  # base case: one candidate, broadcast it to every walker
  if count == 1: return values[0].expand(relative_idx.shape)

  half = count // 2
  lo = select_with_where_tree(values[:half], relative_idx)
  hi = select_with_where_tree(values[half:], relative_idx - half)

  # per-element choice between the lower and upper half
  return (relative_idx < half).where(lo, hi)
|
||||
|
||||
def tree_traversal(forest: Tensor, val: Tensor, height: int, rounds: int, where_tree_threshold=3) -> Tensor:
  """Run `rounds` hash-and-descend steps over a complete binary tree.

  Each walker XORs its value with the node it sits on, hashes the result,
  then steps to a child picked by the new value's low bit. Shallow levels
  are read with a WHERE-tree, deep levels with a gather.
  """
  # every walker begins at the root (index 0)
  idx = Tensor.zeros(val.shape, device=val.device, dtype=dtypes.uint32)

  for rnd in range(rounds):
    level = rnd % (height + 1)
    level_start = (1 << level) - 1
    level_size = 1 << level

    if level == 0:
      # root round: all walkers read node 0, so just broadcast it
      node_val = forest[0].expand(val.shape)
      idx = idx * 0  # wrap every walker back to the root
    elif level <= where_tree_threshold:
      # shallow level: few candidates, a WHERE-tree beats a gather
      candidates = forest[level_start : level_start + level_size]
      node_val = select_with_where_tree(candidates, idx - level_start)
    else:
      # deep level: too many candidates, fall back to a real gather
      node_val = forest.gather(0, idx)

    val = myhash(val ^ node_val)
    # descend: children of idx are 2*idx+1 / 2*idx+2, chosen by val's low bit
    idx = (idx << 1) + (1 + (val & 1))

  # no wrap check needed: when level comes back to 0, idx is reset above

  return val.contiguous(arg=(Opt(OptOps.UPCAST, 0, 8),))
|
||||
|
||||
# ************************* renderer for VLIW machine *************************
|
||||
|
||||
def loop_unrolling(sink:UOp):
  """Fully unroll the first RANGE in `sink` by pinning it to each concrete index.

  Returns None (no rewrite) when the graph contains no RANGE.
  """
  ranges = [u for u in sink.toposort() if u.op is Ops.RANGE]
  if not ranges: return None
  r0 = ranges[0]
  trip_count = r0.vmax+1
  print(f"unrolling loop with size {trip_count}")
  # one copy of the sink body per iteration, with the range replaced by the constant i
  bodies = [sink.substitute({r0:r0.const_like(i)}).src[0] for i in range(trip_count)]
  return UOp.sink(*bodies, arg=sink.arg)
|
||||
|
||||
# base addresses for each DEFINE_GLOBAL buffer; filled in by the driver below
# before rendering (index = the DEFINE_GLOBAL's arg)
global_addrs = []

# graph-level rewrites that adapt the UOp graph to the VLIW machine
vliw_prepare = PatternMatcher([
  # loop unrolling (should be a part of tinygrad)
  (UPat(Ops.SINK, name="sink"), loop_unrolling),
  # cast is fake: the machine is untyped, so drop CASTs entirely
  (UPat(Ops.CAST, name="c"), lambda c: c.src[0]),
  # rewrites to hardcode the addresses in memory (buffer -> constant address)
  (UPat(Ops.DEFINE_GLOBAL, name="dg"), lambda dg: UOp.const(dtypes.uint, global_addrs[dg.arg])),
  # INDEX is just plus: address arithmetic is base + offset
  (UPat(Ops.INDEX, name="i"), lambda i: i.src[0]+i.src[1]),
])+symbolic
|
||||
|
||||
class VLIWRenderer(Renderer):
  """Minimal renderer targeting the challenge's VLIW machine simulator.

  Emits repr() of a list of instruction dicts; each dict maps an execution
  slot name ("load", "store", "alu", "valu", "flow") to a list of tuples of
  (opcode, register numbers / immediates). Registers are allocated linearly
  and never reused (see the TODOs in render).
  """
  has_local = False  # TODO: this should be the default / cleaned up
  # this says this backend supports MULACC + more. decompositions uses this
  code_for_op: dict = {Ops.MULACC: None, Ops.ADD: "+", Ops.MUL: "*",
                       Ops.XOR: "^", Ops.AND: "&", Ops.OR: "|",
                       Ops.SHL: "<<", Ops.SHR: ">>", Ops.CMPLT: "<"}
  # this matcher runs while still in graph form
  pre_matcher = vliw_prepare

  def render(self, uops:list[UOp]) -> str:
    """Translate a linearized UOp list into the machine's instruction format."""
    # TODO: this is a minimal renderer. for low cycle count, make it good
    # to get speed, you need to add VLIW packing
    # to get under 1536 regs, you need to add a register allocator
    # we left the fun parts to you

    print(f"rendering with {len(uops)} uops")
    reg, inst = 0, []
    # r maps each UOp to the first register of its (possibly vector) result
    r: dict[UOp, int] = {}
    for u in uops:
      # the machine only deals in scalars and 8-wide vectors
      assert u.dtype.count in (1,8), "dtype count must be 1 or 8"

      # dumb register allocator: bump-allocate, never free.
      # STORE/SINK produce no value; GEP aliases its source (assigned below)
      if u.op not in {Ops.STORE, Ops.SINK, Ops.GEP}:
        r[u] = reg
        reg += u.dtype.count

      # render UOps to instructions
      match u.op:
        case Ops.SINK:
          # end of program
          inst.append({"flow": [("halt",)]})
        case Ops.CONST:
          inst.append({"load": [("const", r[u], u.arg)]})
        case Ops.GEP:
          # a GEP is just an alias to a special register in the vector
          r[u] = r[u.src[0]] + u.arg[0]
        case Ops.VECTORIZE:
          if all(s == u.src[0] for s in u.src):
            # if all sources are the same, we can broadcast
            inst.append({"valu": [("vbroadcast", r[u], r[u.src[0]])]})
          else:
            # this is a copy into a contiguous chunk of registers
            # (add_imm with 0 acts as a move; skip lanes already in place)
            inst.extend({"flow": [("add_imm", r[u]+i, r[s], 0)]} for i,s in enumerate(u.src) if r[s] != r[u]+i)
        case Ops.LOAD:
          # vector loads when the result dtype is 8-wide
          op = "vload" if u.dtype.count > 1 else "load"
          inst.append({"load": [(op, r[u], r[u.src[0]])]})
        case Ops.STORE:
          # src[0] is the address, src[1] the value being stored
          op = "vstore" if u.src[1].dtype.count > 1 else "store"
          inst.append({"store": [(op, r[u.src[0]], r[u.src[1]])]})
        case Ops.MULACC:
          # fused multiply-add, vector-only on this machine
          assert u.dtype.count == 8
          inst.append({"valu": [("multiply_add", r[u], r[u.src[0]], r[u.src[1]], r[u.src[2]])]})
        case Ops.WHERE:
          # vector select: src[0] is the condition, src[1]/src[2] the branches
          assert u.dtype.count == 8
          inst.append({"flow": [("vselect", r[u], r[u.src[0]], r[u.src[1]], r[u.src[2]])]})
        case _ if u.op in self.code_for_op:
          # plain two-operand ALU op; the opcode string is the Python operator
          cat = "valu" if u.dtype.count > 1 else "alu"
          inst.append({cat: [(self.code_for_op[u.op], r[u], r[u.src[0]], r[u.src[1]])]})
        case _:
          raise NotImplementedError(f"unhandled op {u.op}")
    return repr(inst)
|
||||
|
||||
# ************************* test and render *************************
|
||||
|
||||
import sys, types
# frozen problem definition: provides Tree, Input, Machine, build_mem_image and
# the reference kernels used by the driver below
PROBLEM_URL = "https://raw.githubusercontent.com/anthropics/original_performance_takehome/refs/heads/main/tests/frozen_problem.py"
# synthesize a "problem" module from the downloaded source so it is importable
sys.modules["problem"] = problem = types.ModuleType("problem")
# SECURITY: this exec()s code fetched over the network — only run against the
# trusted, pinned URL above
exec(fetch(PROBLEM_URL).read_text(), problem.__dict__)
|
||||
|
||||
if __name__ == "__main__":
  batch_size = getenv("BS", 256)  # number of parallel walkers
  height = 10                     # fixed tree height from the frozen problem
  rounds = getenv("ROUNDS", 16)   # traversal steps per walker

  # build problem
  tree = problem.Tree.generate(height)
  inp = problem.Input.generate(tree, batch_size, rounds)
  mem = problem.build_mem_image(tree, inp)
  # NOTE(review): output and input both map to mem[6] — presumably the kernel
  # updates the values in place; confirm against build_mem_image's layout
  global_addrs.extend([mem[6], mem[6], mem[4]]) # output, input, forest

  # *** verify the kernel in tinygrad compared to reference ***

  forest_t = Tensor(tree.values, dtype=dtypes.uint32)
  val_t = Tensor(inp.values, dtype=dtypes.uint32)

  if getenv("VERIFY", 1):
    # verify on normal tinygrad device
    with Context(PCONTIG=2):
      out = tree_traversal(forest_t, val_t, height, rounds)
      val_out = out.tolist()
    # reference_kernel mutates inp.values in place; compare against that
    problem.reference_kernel(tree, inp)
    assert val_out == inp.values
    print("verification passed")

  # *** render to device ***

  from tinygrad.codegen import get_program
  # DEVECTORIZE=2 keeps vector ops intact for the VLIW backend; SPEC=0 skips spec checks
  with Context(PCONTIG=2, DEVECTORIZE=2, SPEC=0):
    out = tree_traversal(forest_t, val_t, height, rounds)
    sink = out.schedule()[-1].ast
    prg = get_program(sink, VLIWRenderer())

  # *** run on Machine and compare ***

  # NOTE: the scratch size needs to be reduced to 1536 when you have a register allocator
  # the renderer emitted repr() of the instruction list; eval turns it back into data
  src = eval(prg.src)
  # highest destination register touched by any instruction, plus vector-width headroom
  max_regs = max(t[1] for instr in src for v in instr.values() for t in v if len(t) > 1) + 8
  print(f"{max_regs:5d} regs used" + ("" if max_regs <= 1536 else " <-- WARNING: TOO MANY REGISTERS, MUST BE <= 1536"))
  machine = problem.Machine(mem, src, problem.DebugInfo(scratch_map={}), n_cores=1, trace=False, scratch_size=max_regs)
  machine.run()
  print(f"ran for {machine.cycle:5d} cycles" + ("" if machine.cycle <= 1363 else " <-- EVEN CLAUDE GOT 1363"))

  # compare to reference
  ref_mem = mem.copy()
  # reference_kernel2 is a generator over a raw memory image; drain it fully
  for _ in problem.reference_kernel2(ref_mem, {}): pass
  # NOTE(review): mem[6] looks like the output base address and mem[2] its
  # length — verify against build_mem_image
  assert machine.mem[mem[6]:mem[6]+mem[2]] == ref_mem[mem[6]:mem[6]+mem[2]]
  print("compare passed!")
|
||||
@@ -84,10 +84,10 @@ def full_rewrite_to_sink(sink:UOp, ren:Renderer|None=None, optimize:bool=True) -
|
||||
if DEVECTORIZE >= 2: pm_devectorize = sym+load_store_folding+load_store_indexing
|
||||
elif DEVECTORIZE: pm_devectorize = sym+devectorize+load_store_folding+correct_load_store+load_store_indexing
|
||||
else: pm_devectorize = sym+load_store_folding+correct_load_store+load_store_indexing
|
||||
sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize")
|
||||
if DEVECTORIZE >= 0: sink = graph_rewrite(sink, pm_devectorize, ctx=ren, name="devectorize")
|
||||
|
||||
# lower the index dtype to a concrete int
|
||||
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing, ctx=ren.device, name="lower all index dtypes")
|
||||
sink = graph_rewrite(sink, pm_lower_index_dtype+load_store_indexing+gep_pushing, ctx=ren.device, name="lower all index dtypes")
|
||||
sink = graph_rewrite(sink, symbolic, name="post index symbolic")
|
||||
|
||||
# optional pre matcher
|
||||
|
||||
@@ -243,11 +243,17 @@ def no_vectorized_index(buf:UOp, cast:UOp, idx:UOp):
|
||||
|
||||
def no_vectorized_index_broadcast(buf:UOp, cast:UOp, bcast:UOp, idx:UOp):
|
||||
cnt = cast.dtype.count
|
||||
vcnt = cast.dtype.vcount
|
||||
precnt = bcast.dtype.vcount
|
||||
input_gep = bcast.arg if bcast.op is Ops.GEP else ([0]*precnt)
|
||||
gep_arg = tuple(flatten([range(precnt) for _ in range(cnt)]))
|
||||
sum_arg = tuple(flatten([[i+y for y in input_gep] for i in range(cnt)]))
|
||||
return buf.broadcast(cnt*precnt).index(idx.gep(gep_arg)*cnt+UOp.const(dtypes.index.vec(cnt*precnt), sum_arg), ptr=True)
|
||||
# TODO: I have no idea *why* this is. I just change things until the tests pass. No AI, old school.
|
||||
if bcast.op is Ops.GEP:
|
||||
gep_arg = tuple(flatten([range(precnt) for _ in range(vcnt)]))
|
||||
sum_arg = tuple(flatten([[i+y for y in bcast.arg] for i in range(vcnt)]))
|
||||
else:
|
||||
gep_arg = tuple(flatten([range(precnt) for _ in range(cnt)]))
|
||||
sum_arg = tuple(flatten([[i]*precnt for i in range(cnt)]))
|
||||
new_idx = idx.gep(gep_arg)*cnt + UOp.const(dtypes.index.vec(len(sum_arg)), sum_arg)
|
||||
return buf.broadcast(cnt*precnt).index(new_idx, ptr=True)
|
||||
|
||||
devectorize_buf_and_index = PatternMatcher([
|
||||
(UPat((Ops.DEFINE_LOCAL, Ops.DEFINE_REG), name="buf"), no_vectorized_buf),
|
||||
|
||||
@@ -30,7 +30,6 @@ def linearize(sink:UOp) -> list[UOp]:
|
||||
case Ops.DEFINE_VAR: priority, extra = -19, u.arg
|
||||
case Ops.DEFINE_LOCAL: priority = -18
|
||||
case Ops.DEFINE_REG: priority = -17
|
||||
case Ops.CONST: priority = -10 # early consts
|
||||
case Ops.LOAD: priority = -1 # place loads early
|
||||
case Ops.STORE: priority = 1 # place stores late
|
||||
case Ops.RANGE: priority = 5 # placing RANGE is good
|
||||
|
||||
@@ -15,8 +15,9 @@ def validate_index(buf:UOp, idx:UOp, gate:UOp|None=None):
|
||||
|
||||
# TODO: validate these
|
||||
# WEBGPU has a BITCAST in the index, PTX casts pointer to long
|
||||
# VECTORIZE/GEP can't be properly modeled in z3 since it doesn't support vectors
|
||||
for x in idx.toposort() | gate.toposort():
|
||||
if x.op is Ops.BITCAST or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True
|
||||
if x.op in {Ops.BITCAST, Ops.VECTORIZE, Ops.GEP} or (x.op is Ops.CAST and isinstance(x.src[0].dtype, PtrDType)): return True
|
||||
|
||||
# if all is good and IGNORE_OOB=0, validate with z3
|
||||
from tinygrad.uop.validate import validate_index_with_z3
|
||||
|
||||
@@ -163,8 +163,8 @@ gep_pushing = PatternMatcher([
|
||||
(UPat(Ops.GEP, src=(UPat(dtype=dtypes.void, name="x"),)), lambda x: x),
|
||||
# GEP in order is removed
|
||||
(UPat(Ops.GEP, name="g"), lambda g: g.src[0] if not isinstance(g.dtype, PtrDType) and g.arg == tuple(range(g.src[0].dtype.count)) else None),
|
||||
# push all GEPs through ALUs (fix arange stuff)
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu').f(Ops.GEP, name='gep'),
|
||||
# push all GEPs through ALUs for index (TODO: remove this)
|
||||
(UPat((*GroupOp.ALU, Ops.CAST, Ops.BITCAST), name='alu').f(Ops.GEP, dtype=dtypes.index, name='gep'),
|
||||
lambda gep,alu: UOp(alu.op, alu.dtype.scalar().vec(gep.dtype.count), tuple(x.gep(gep.arg) for x in alu.src), alu.arg) \
|
||||
if not isinstance(gep.dtype, PtrDType) and not isinstance(alu.dtype, PtrDType) else None),
|
||||
# CAT can't be rendered. it's a VECTORIZE on vectors, we expand to a single VECTORIZEs with GEPs (TODO: move this later)
|
||||
|
||||
@@ -48,7 +48,9 @@ def uops_to_z3(solver, *uops: UOp) -> list[z3.ExprRef]:
|
||||
lst = list(UOp.sink(*uops).toposort(gate=lambda x: x.dtype.scalar() in dtypes.ints+(dtypes.bool, dtypes.index) or x.op is Ops.SINK))[:-1]
|
||||
z3map: dict[UOp, z3.ExprRef] = {}
|
||||
for i,u in enumerate(lst):
|
||||
new_u, constraint = cast(tuple[z3.ArithRef, z3.BoolRef|None], z3_renderer.rewrite(u, ctx=(solver, z3map)))
|
||||
z3_rewritten = z3_renderer.rewrite(u, ctx=(solver, z3map))
|
||||
if z3_rewritten is None: raise NotImplementedError(f"{u.op} is not supported by z3")
|
||||
new_u, constraint = cast(tuple[z3.ArithRef, z3.BoolRef|None], z3_rewritten)
|
||||
if constraint is not None: solver.add(constraint)
|
||||
z3map[u] = new_u
|
||||
assert all(u in z3map for u in uops), "UOp failed to rewrite to z3!"
|
||||
|
||||
Reference in New Issue
Block a user