mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
197 lines
7.6 KiB
Python
197 lines
7.6 KiB
Python
from tinygrad import Tensor, dtypes, Context, getenv, UOp, fetch
|
|
from tinygrad.uop.ops import Ops, PatternMatcher, UPat
|
|
from tinygrad.uop.symbolic import symbolic
|
|
from tinygrad.codegen import Renderer
|
|
from tinygrad.codegen.opt import Opt, OptOps
|
|
|
|
# ************************* implementation of the problem ************************
|
|
|
|
def myhash(a: Tensor) -> Tensor:
|
|
a = (a + 0x7ED55D16) + (a << 12)
|
|
a = (a ^ 0xC761C23C) ^ (a >> 19)
|
|
a = (a + 0x165667B1) + (a << 5)
|
|
a = (a + 0xD3A2646C) ^ (a << 9)
|
|
a = (a + 0xFD7046C5) + (a << 3)
|
|
a = (a ^ 0xB55A4F09) ^ (a >> 16)
|
|
return a
|
|
|
|
def select_with_where_tree(values: Tensor, relative_idx: Tensor) -> Tensor:
|
|
n = values.shape[0]
|
|
if n == 1: return values[0].expand(relative_idx.shape)
|
|
|
|
mid = n // 2
|
|
left = select_with_where_tree(values[:mid], relative_idx)
|
|
right = select_with_where_tree(values[mid:], relative_idx - mid)
|
|
|
|
go_left = relative_idx < mid
|
|
return go_left.where(left, right)
|
|
|
|
def tree_traversal(forest: Tensor, val: Tensor, height: int, rounds: int, where_tree_threshold=3) -> Tensor:
|
|
# All walkers start at idx=0
|
|
idx = Tensor.zeros(val.shape, device=val.device, dtype=dtypes.uint32)
|
|
|
|
for r in range(rounds):
|
|
level = r % (height + 1)
|
|
level_start = (1 << level) - 1
|
|
level_size = 1 << level
|
|
|
|
if level == 0:
|
|
# At root (level 0), all walkers are at idx=0
|
|
# No gather needed, just broadcast the root value
|
|
node_val = forest[0].expand(val.shape)
|
|
idx = idx * 0 # Reset to 0
|
|
elif level <= where_tree_threshold:
|
|
# Small level: use where-tree
|
|
level_values = forest[level_start : level_start + level_size]
|
|
relative_idx = (idx - level_start)
|
|
node_val = select_with_where_tree(level_values, relative_idx)
|
|
else:
|
|
# Large level: use gather
|
|
node_val = forest.gather(0, idx)
|
|
|
|
val = myhash(val ^ node_val)
|
|
idx = (idx << 1) + (1 + (val & 1))
|
|
|
|
# No wrap check needed! At round 10 (level becomes 0), we reset idx above.
|
|
|
|
return val.contiguous(arg=(Opt(OptOps.UPCAST, 0, 8),))
|
|
|
|
# ************************* renderer for VLIW machine *************************
|
|
|
|
def loop_unrolling(sink:UOp):
|
|
rng = [x for x in sink.toposort() if x.op is Ops.RANGE]
|
|
if len(rng) == 0: return None
|
|
print(f"unrolling loop with size {rng[0].vmax+1}")
|
|
unrolled_sinks = [sink.substitute({rng[0]:rng[0].const_like(i)}).src[0] for i in range(rng[0].vmax+1)]
|
|
return UOp.sink(*unrolled_sinks, arg=sink.arg)
|
|
|
|
global_addrs = []
|
|
vliw_prepare = PatternMatcher([
|
|
# loop unrolling (should be a part of tinygrad)
|
|
(UPat(Ops.SINK, name="sink"), loop_unrolling),
|
|
# cast is fake
|
|
(UPat(Ops.CAST, name="c"), lambda c: c.src[0]),
|
|
# rewrites to hardcode the addresses in memory
|
|
(UPat(Ops.PARAM, name="dg"), lambda dg: UOp.const(dtypes.uint, global_addrs[dg.arg])),
|
|
# INDEX is just plus
|
|
(UPat(Ops.INDEX, name="i"), lambda i: i.src[0]+i.src[1]),
|
|
])+symbolic
|
|
|
|
class VLIWRenderer(Renderer):
|
|
has_local = False # TODO: this should be the default / cleaned up
|
|
# this says this backend supports MULACC + more. decompositions uses this
|
|
code_for_op: dict = {Ops.MULACC: None, Ops.ADD: "+", Ops.MUL: "*",
|
|
Ops.XOR: "^", Ops.AND: "&", Ops.OR: "|",
|
|
Ops.SHL: "<<", Ops.SHR: ">>", Ops.CMPLT: "<"}
|
|
# this matcher runs while still in graph form
|
|
pre_matcher = vliw_prepare
|
|
|
|
def render(self, uops:list[UOp]):
|
|
|
|
# TODO: this is a minimal renderer. for low cycle count, make it good
|
|
# to get speed, you need to add VLIW packing
|
|
# to get under 1536 regs, you need to add a register allocator
|
|
# we left the fun parts to you
|
|
|
|
print(f"rendering with {len(uops)} uops")
|
|
reg, inst = 0, []
|
|
r: dict[UOp, int] = {}
|
|
for u in uops:
|
|
assert u.dtype.count in (1,8), "dtype count must be 1 or 8"
|
|
|
|
# dumb register allocator
|
|
if u.op not in {Ops.STORE, Ops.SINK, Ops.GEP}:
|
|
r[u] = reg
|
|
reg += u.dtype.count
|
|
|
|
# render UOps to instructions
|
|
match u.op:
|
|
case Ops.SINK:
|
|
inst.append({"flow": [("halt",)]})
|
|
case Ops.CONST:
|
|
inst.append({"load": [("const", r[u], u.arg)]})
|
|
case Ops.GEP:
|
|
# a GEP is just an alias to a special register in the vector
|
|
r[u] = r[u.src[0]] + u.arg[0]
|
|
case Ops.VECTORIZE:
|
|
if all(s == u.src[0] for s in u.src):
|
|
# if all sources are the same, we can broadcast
|
|
inst.append({"valu": [("vbroadcast", r[u], r[u.src[0]])]})
|
|
else:
|
|
# this is a copy into a contiguous chunk of registers
|
|
inst.extend({"flow": [("add_imm", r[u]+i, r[s], 0)]} for i,s in enumerate(u.src) if r[s] != r[u]+i)
|
|
case Ops.LOAD:
|
|
op = "vload" if u.dtype.count > 1 else "load"
|
|
inst.append({"load": [(op, r[u], r[u.src[0]])]})
|
|
case Ops.STORE:
|
|
op = "vstore" if u.src[1].dtype.count > 1 else "store"
|
|
inst.append({"store": [(op, r[u.src[0]], r[u.src[1]])]})
|
|
case Ops.MULACC:
|
|
assert u.dtype.count == 8
|
|
inst.append({"valu": [("multiply_add", r[u], r[u.src[0]], r[u.src[1]], r[u.src[2]])]})
|
|
case Ops.WHERE:
|
|
assert u.dtype.count == 8
|
|
inst.append({"flow": [("vselect", r[u], r[u.src[0]], r[u.src[1]], r[u.src[2]])]})
|
|
case _ if u.op in self.code_for_op:
|
|
cat = "valu" if u.dtype.count > 1 else "alu"
|
|
inst.append({cat: [(self.code_for_op[u.op], r[u], r[u.src[0]], r[u.src[1]])]})
|
|
case _:
|
|
raise NotImplementedError(f"unhandled op {u.op}")
|
|
return repr(inst)
|
|
|
|
# ************************* test and render *************************
|
|
|
|
import sys, types
|
|
PROBLEM_URL = "https://raw.githubusercontent.com/anthropics/original_performance_takehome/refs/heads/main/tests/frozen_problem.py"
|
|
sys.modules["problem"] = problem = types.ModuleType("problem")
|
|
exec(fetch(PROBLEM_URL).read_text(), problem.__dict__)
|
|
|
|
if __name__ == "__main__":
|
|
batch_size = getenv("BS", 256)
|
|
height = 10
|
|
rounds = getenv("ROUNDS", 16)
|
|
|
|
# build problem
|
|
tree = problem.Tree.generate(height)
|
|
inp = problem.Input.generate(tree, batch_size, rounds)
|
|
mem = problem.build_mem_image(tree, inp)
|
|
global_addrs.extend([mem[6], mem[6], mem[4]]) # output, input, forest
|
|
|
|
# *** verify the kernel in tinygrad compared to reference ***
|
|
|
|
forest_t = Tensor(tree.values, dtype=dtypes.uint32)
|
|
val_t = Tensor(inp.values, dtype=dtypes.uint32)
|
|
|
|
if getenv("VERIFY", 1):
|
|
# verify on normal tinygrad device
|
|
with Context(PCONTIG=2):
|
|
out = tree_traversal(forest_t, val_t, height, rounds)
|
|
val_out = out.tolist()
|
|
problem.reference_kernel(tree, inp)
|
|
assert val_out == inp.values
|
|
print("verification passed")
|
|
|
|
# *** render to device ***
|
|
|
|
from tinygrad.codegen import get_program
|
|
with Context(PCONTIG=2, DEVECTORIZE=2, SPEC=0):
|
|
out = tree_traversal(forest_t, val_t, height, rounds)
|
|
sink = out.schedule()[-1].ast
|
|
prg = get_program(sink, VLIWRenderer())
|
|
|
|
# *** run on Machine and compare ***
|
|
|
|
# NOTE: the scratch size needs to be reduced to 1536 when you have a register allocator
|
|
src = eval(prg.src)
|
|
max_regs = max(t[1] for instr in src for v in instr.values() for t in v if len(t) > 1) + 8
|
|
print(f"{max_regs:5d} regs used" + ("" if max_regs <= 1536 else " <-- WARNING: TOO MANY REGISTERS, MUST BE <= 1536"))
|
|
machine = problem.Machine(mem, src, problem.DebugInfo(scratch_map={}), n_cores=1, trace=False, scratch_size=max_regs)
|
|
machine.run()
|
|
print(f"ran for {machine.cycle:5d} cycles" + ("" if machine.cycle <= 1363 else " <-- EVEN CLAUDE GOT 1363"))
|
|
|
|
# compare to reference
|
|
ref_mem = mem.copy()
|
|
for _ in problem.reference_kernel2(ref_mem, {}): pass
|
|
assert machine.mem[mem[6]:mem[6]+mem[2]] == ref_mem[mem[6]:mem[6]+mem[2]]
|
|
print("compare passed!")
|