from tinygrad import Tensor, dtypes, Context, getenv, UOp, fetch from tinygrad.uop.ops import Ops, PatternMatcher, UPat from tinygrad.uop.symbolic import symbolic from tinygrad.codegen import Renderer from tinygrad.codegen.opt import Opt, OptOps # ************************* implementation of the problem ************************ def myhash(a: Tensor) -> Tensor: a = (a + 0x7ED55D16) + (a << 12) a = (a ^ 0xC761C23C) ^ (a >> 19) a = (a + 0x165667B1) + (a << 5) a = (a + 0xD3A2646C) ^ (a << 9) a = (a + 0xFD7046C5) + (a << 3) a = (a ^ 0xB55A4F09) ^ (a >> 16) return a def select_with_where_tree(values: Tensor, relative_idx: Tensor) -> Tensor: n = values.shape[0] if n == 1: return values[0].expand(relative_idx.shape) mid = n // 2 left = select_with_where_tree(values[:mid], relative_idx) right = select_with_where_tree(values[mid:], relative_idx - mid) go_left = relative_idx < mid return go_left.where(left, right) def tree_traversal(forest: Tensor, val: Tensor, height: int, rounds: int, where_tree_threshold=3) -> Tensor: # All walkers start at idx=0 idx = Tensor.zeros(val.shape, device=val.device, dtype=dtypes.uint32) for r in range(rounds): level = r % (height + 1) level_start = (1 << level) - 1 level_size = 1 << level if level == 0: # At root (level 0), all walkers are at idx=0 # No gather needed, just broadcast the root value node_val = forest[0].expand(val.shape) idx = idx * 0 # Reset to 0 elif level <= where_tree_threshold: # Small level: use where-tree level_values = forest[level_start : level_start + level_size] relative_idx = (idx - level_start) node_val = select_with_where_tree(level_values, relative_idx) else: # Large level: use gather node_val = forest.gather(0, idx) val = myhash(val ^ node_val) idx = (idx << 1) + (1 + (val & 1)) # No wrap check needed! At round 10 (level becomes 0), we reset idx above. return val.contiguous(arg=(Opt(OptOps.UPCAST, 0, 8),)) # ************************* renderer for VLIW machine ************************* def loop_unrolling(sink:UOp): rng = [x for x in sink.toposort() if x.op is Ops.RANGE] if len(rng) == 0: return None print(f"unrolling loop with size {rng[0].vmax+1}") unrolled_sinks = [sink.substitute({rng[0]:rng[0].const_like(i)}).src[0] for i in range(rng[0].vmax+1)] return UOp.sink(*unrolled_sinks, arg=sink.arg) global_addrs = [] vliw_prepare = PatternMatcher([ # loop unrolling (should be a part of tinygrad) (UPat(Ops.SINK, name="sink"), loop_unrolling), # cast is fake (UPat(Ops.CAST, name="c"), lambda c: c.src[0]), # rewrites to hardcode the addresses in memory (UPat(Ops.PARAM, name="dg"), lambda dg: UOp.const(dtypes.uint, global_addrs[dg.arg])), # INDEX is just plus (UPat(Ops.INDEX, name="i"), lambda i: i.src[0]+i.src[1]), ])+symbolic class VLIWRenderer(Renderer): has_local = False # TODO: this should be the default / cleaned up # this says this backend supports MULACC + more. decompositions uses this code_for_op: dict = {Ops.MULACC: None, Ops.ADD: "+", Ops.MUL: "*", Ops.XOR: "^", Ops.AND: "&", Ops.OR: "|", Ops.SHL: "<<", Ops.SHR: ">>", Ops.CMPLT: "<"} # this matcher runs while still in graph form pre_matcher = vliw_prepare def render(self, uops:list[UOp]): # TODO: this is a minimal renderer. for low cycle count, make it good # to get speed, you need to add VLIW packing # to get under 1536 regs, you need to add a register allocator # we left the fun parts to you print(f"rendering with {len(uops)} uops") reg, inst = 0, [] r: dict[UOp, int] = {} for u in uops: assert u.dtype.count in (1,8), "dtype count must be 1 or 8" # dumb register allocator if u.op not in {Ops.STORE, Ops.SINK, Ops.GEP}: r[u] = reg reg += u.dtype.count # render UOps to instructions match u.op: case Ops.SINK: inst.append({"flow": [("halt",)]}) case Ops.CONST: inst.append({"load": [("const", r[u], u.arg)]}) case Ops.GEP: # a GEP is just an alias to a special register in the vector r[u] = r[u.src[0]] + u.arg[0] case Ops.VECTORIZE: if all(s == u.src[0] for s in u.src): # if all sources are the same, we can broadcast inst.append({"valu": [("vbroadcast", r[u], r[u.src[0]])]}) else: # this is a copy into a contiguous chunk of registers inst.extend({"flow": [("add_imm", r[u]+i, r[s], 0)]} for i,s in enumerate(u.src) if r[s] != r[u]+i) case Ops.LOAD: op = "vload" if u.dtype.count > 1 else "load" inst.append({"load": [(op, r[u], r[u.src[0]])]}) case Ops.STORE: op = "vstore" if u.src[1].dtype.count > 1 else "store" inst.append({"store": [(op, r[u.src[0]], r[u.src[1]])]}) case Ops.MULACC: assert u.dtype.count == 8 inst.append({"valu": [("multiply_add", r[u], r[u.src[0]], r[u.src[1]], r[u.src[2]])]}) case Ops.WHERE: assert u.dtype.count == 8 inst.append({"flow": [("vselect", r[u], r[u.src[0]], r[u.src[1]], r[u.src[2]])]}) case _ if u.op in self.code_for_op: cat = "valu" if u.dtype.count > 1 else "alu" inst.append({cat: [(self.code_for_op[u.op], r[u], r[u.src[0]], r[u.src[1]])]}) case _: raise NotImplementedError(f"unhandled op {u.op}") return repr(inst) # ************************* test and render ************************* import sys, types PROBLEM_URL = "https://raw.githubusercontent.com/anthropics/original_performance_takehome/refs/heads/main/tests/frozen_problem.py" sys.modules["problem"] = problem = types.ModuleType("problem") exec(fetch(PROBLEM_URL).read_text(), problem.__dict__) if __name__ == "__main__": batch_size = getenv("BS", 256) height = 10 rounds = getenv("ROUNDS", 16) # build problem tree = problem.Tree.generate(height) inp = problem.Input.generate(tree, batch_size, rounds) mem = problem.build_mem_image(tree, inp) global_addrs.extend([mem[6], mem[6], mem[4]]) # output, input, forest # *** verify the kernel in tinygrad compared to reference *** forest_t = Tensor(tree.values, dtype=dtypes.uint32) val_t = Tensor(inp.values, dtype=dtypes.uint32) if getenv("VERIFY", 1): # verify on normal tinygrad device with Context(PCONTIG=2): out = tree_traversal(forest_t, val_t, height, rounds) val_out = out.tolist() problem.reference_kernel(tree, inp) assert val_out == inp.values print("verification passed") # *** render to device *** from tinygrad.codegen import get_program with Context(PCONTIG=2, DEVECTORIZE=2, SPEC=0): out = tree_traversal(forest_t, val_t, height, rounds) sink = out.schedule()[-1].ast prg = get_program(sink, VLIWRenderer()) # *** run on Machine and compare *** # NOTE: the scratch size needs to be reduced to 1536 when you have a register allocator src = eval(prg.src) max_regs = max(t[1] for instr in src for v in instr.values() for t in v if len(t) > 1) + 8 print(f"{max_regs:5d} regs used" + ("" if max_regs <= 1536 else " <-- WARNING: TOO MANY REGISTERS, MUST BE <= 1536")) machine = problem.Machine(mem, src, problem.DebugInfo(scratch_map={}), n_cores=1, trace=False, scratch_size=max_regs) machine.run() print(f"ran for {machine.cycle:5d} cycles" + ("" if machine.cycle <= 1363 else " <-- EVEN CLAUDE GOT 1363")) # compare to reference ref_mem = mem.copy() for _ in problem.reference_kernel2(ref_mem, {}): pass assert machine.mem[mem[6]:mem[6]+mem[2]] == ref_mem[mem[6]:mem[6]+mem[2]] print("compare passed!")