a bunch of todos for my boy claude

This commit is contained in:
George Hotz
2026-01-06 14:02:29 -08:00
parent add569d94c
commit c8b42edec6
3 changed files with 47 additions and 10 deletions

View File

@@ -11,7 +11,8 @@ from tinygrad.helpers import DEBUG
from tinygrad.dtype import INVERSE_DTYPES_DICT
_QDTYPES: dict[str, DType] = {
'f64': dtypes.float64, 'f32': dtypes.float32, 'f16': dtypes.float16, 'bf16': dtypes.bfloat16,
'fp8': DType.new(4, 8, "fp8", None), 'bf8': DType.new(4, 8, "bf8", None), 'fp6': DType.new(4, 6, "fp6", None), 'bf6': DType.new(4, 6, "bf6", None),
'fp8': DType.new(4, 8, "fp8", None), 'bf8': DType.new(4, 8, "bf8", None),
'fp6': DType.new(4, 6, "fp6", None), 'bf6': DType.new(4, 6, "bf6", None),
'fp4': DType.new(4, 4, "fp4", None), 'i4': DType.new(5, 4, "i4", None),
'u64': dtypes.uint64, 'u32': dtypes.uint32, 'u16': dtypes.uint16, 'u8': dtypes.uint8,
'i64': dtypes.int64, 'i32': dtypes.int32, 'i16': dtypes.int16, 'i8': dtypes.int8,
@@ -20,8 +21,10 @@ _QDTYPES: dict[str, DType] = {
'b65': DType.new(6, 65, "b65", None), 'b64': dtypes.uint64, 'b32': dtypes.uint32, 'b23': DType.new(6, 23, "b23", None), 'b16': dtypes.uint16, 'b8': dtypes.uint8, 'b4': DType.new(6, 4, "b4", None),
'u1201': DType.new(6, 1201, "u1201", None), 'u65': DType.new(6, 65, "u65", None), 'u24': DType.new(6, 24, "u24", None), 'u23': DType.new(6, 23, "u23", None),
'u6': DType.new(6, 6, "u6", None), 'u4': DType.new(6, 4, "u4", None),
# TODO: why is there i1 and u1?
'u3': DType.new(6, 3, "u3", None), 'u1': DType.new(6, 1, "u1", None),
'i65': DType.new(5, 65, "i65", None), 'i24': DType.new(5, 24, "i24", None), 'i1': DType.new(5, 1, "i1", None),
'i65': DType.new(5, 65, "i65", None), 'i24': DType.new(5, 24, "i24", None),
'i1': DType.new(5, 1, "i1", None),
'u': dtypes.uint32, 'i': dtypes.int32, 'f': dtypes.float32,
}
# Register custom dtypes for repr
@@ -36,6 +39,7 @@ _BINOPS: dict[str, Ops] = {
'<': Ops.CMPLT, '>': Ops.CMPLT, '<=': Ops.CMPLE, '>=': Ops.CMPLE,
'||': Ops.OR, '&&': Ops.AND,
}
# TODO: XOR and CMPEQ are binops. see if you can distinguish this based on types and use NEG for all
_UNOPS: dict[str, Ops] = {'-': Ops.NEG, '~': Ops.XOR, '!': Ops.CMPEQ}
# Statement types (control flow, not expressions)
@@ -103,6 +107,7 @@ def expr(s: str) -> UOp:
if s[0] == '{' and s[-1] == '}': return UOp(Ops.CAT, dtypes.void, tuple(expr(a) for a in _split(s[1:-1])))
# Typed cast: 32'U(expr)
if m := re.match(r"^(\d+)'([IUFB])\(", s):
# TODO: is this cast or bitcast? I think it's also BITCAST, but understand better
if (e := _match(s, m.end()-1, '(', ')')) == len(s)-1: return UOp(Ops.CAST, _QDTYPES[f"{m[2].lower()}{m[1]}"], (expr(s[m.end():e]),))
# Typed constant: 32'-5I
if m := re.match(r"^(\d+)'(-?\d+)([IUFB])?$", s):
@@ -115,7 +120,10 @@ def expr(s: str) -> UOp:
if m := re.match(r"^([A-Za-z_]\w*)\(", s):
if (e := _match(s, m.end()-1, '(', ')')) == len(s)-1:
a = _split(s[m.end():e])
return UOp(Ops.CUSTOM, dtypes.void, tuple(expr(x) for x in a) if a != [''] else (), arg=m[1])
srcs = tuple(expr(x) for x in a) if a != [''] else ()
output_dtype = dtypes.void
# TODO: set the output dtypes to reasonable things based on the input dtypes
return UOp(Ops.CUSTOM, output_dtype, srcs, arg=m[1])
# MEM[addr] -> CUSTOM('MEM', addr), MEM[addr].type -> BITCAST
if s[:4] == 'MEM[' and (e := _match(s, 3, '[', ']')) != -1:
r, b = s[e+1:], UOp(Ops.CUSTOM, dtypes.void, (expr(s[4:e]),), arg='MEM')
@@ -129,7 +137,12 @@ def expr(s: str) -> UOp:
elif s[i] == ')': d -= 1
elif s[i] == '[': b += 1
elif s[i] == ']': b -= 1
elif s[i] == ':' and d == 0 and b == 0: return UOp(Ops.WHERE, dtypes.void, (expr(s[:q]), expr(s[q+1:i]), expr(s[i+1:])))
elif s[i] == ':' and d == 0 and b == 0:
gate, lhs, rhs = expr(s[:q]), expr(s[q+1:i]), expr(s[i+1:])
# TODO: enable this. dtypes.u1 and dtypes.i1 is probably fine too on gate
#assert gate.dtype == dtypes.bool, f"gate on where must be bool, got {gate.dtype}"
#assert lhs.dtype != dtypes.void or rhs.dtype != dtypes.void, "lhs/rhs can't both be void in WHERE"
return UOp(Ops.WHERE, dtypes.void, (gate, lhs, rhs))
# Binary ops
for ops in [('||',),('&&',),('|',),('^',),('&',),('==','!=','<>'),('<=','>=','<','>'),('<<','>>'),
('+','-'),('*','/','%'),('**',)]:
@@ -141,9 +154,17 @@ def expr(s: str) -> UOp:
flipped = op in ('>', '>=')
if flipped: lhs, rhs = rhs, lhs
tag = 'flipped' if flipped else ('<>' if op == '<>' else None)
return UOp(_BINOPS[op], dtypes.void, (lhs, rhs), tag=tag)
uop_op = _BINOPS[op]
# TODO: enable this
#assert lhs.dtype != dtypes.void or rhs.dtype != dtypes.void, f"lhs/rhs can't both be void in {uop_op}"
# TODO: check that dtypes match
output_dtype = lhs.dtype if lhs.dtype != dtypes.void else rhs.dtype
if uop_op in {Ops.CMPNE, Ops.CMPEQ, Ops.CMPLE, Ops.CMPLT}: output_dtype = dtypes.bool
return UOp(uop_op, output_dtype, (lhs, rhs), tag=tag)
# Unary ops
if s[0] in '-~!' and len(s) > 1 and (s[0] != '!' or s[1] != '='): return UOp(_UNOPS[s[0]], dtypes.void, (expr(s[1:]),))
if s[0] in '-~!' and len(s) > 1 and (s[0] != '!' or s[1] != '='):
src = expr(s[1:])
return UOp(_UNOPS[s[0]], src.dtype, (src,))
# Slice/Index -> CUSTOMI
if '[' in s and s[-1] == ']':
d = 0
@@ -154,7 +175,9 @@ def expr(s: str) -> UOp:
b, n = s[:i], s[i+1:-1]
if '+:' in n: # Verilog [start +: width]
st, w = expr(n.split('+:', 1)[0]), expr(n.split('+:', 1)[1])
# TODO: correct dtypes here
hi = UOp(Ops.SUB, dtypes.void, (UOp(Ops.ADD, dtypes.void, (st, w)), UOp(Ops.CONST, dtypes.int32, arg=1)))
# TODO: use SHRINK instead of CUSTOMI, it's more sensible even if it doesn't perfectly match spec
return UOp(Ops.CUSTOMI, dtypes.void, (expr(b), hi, st))
if ':' in n and '?' not in n:
d = 0
@@ -172,6 +195,7 @@ def expr(s: str) -> UOp:
if s[:5] == 'eval ': return UOp(Ops.DEFINE_VAR, dtypes.void, arg=(s, None, None))
if re.match(r'^[A-Za-z_][\w.]*$', s): return UOp(Ops.DEFINE_VAR, dtypes.void, arg=(s, None, None))
# Numeric literal
# TODO: hex constants are unsigned even without U
try:
if s[:2].lower() == '0x':
m = re.match(r'0[xX]([0-9a-fA-F]+)([UuLl]*)$', s)
@@ -191,6 +215,7 @@ def expr(s: str) -> UOp:
raise ValueError(f"Cannot parse expression: {s}")
def stmt(line: str) -> Stmt|None:
# TODO: track dtypes of variables here on inputs. SCC is dtypes.bool by default. ADDR is dtypes.uint64
line = line.split('//')[0].strip().rstrip(';').rstrip('.')
if not line: return None
if line == 'break': return Break()
@@ -203,13 +228,15 @@ def stmt(line: str) -> Stmt|None:
t = t.split('[')[0] # strip array suffix like [64]
if m := re.match(r"^(\d+)'([IUFB])$", t):
dt = _QDTYPES[f"{m[2].lower()}{m[1]}"]
# TODO: track the dtypes of the created variables here
return Declare(n.strip(), dt.vec(vec_count) if vec_count > 1 else dt)
return None # unsupported declare type
for op, uop in [('+=', Ops.ADD), ('-=', Ops.SUB), ('|=', Ops.OR), ('&=', Ops.AND), ('^=', Ops.XOR), ('<<=', Ops.SHL), ('>>=', Ops.SHR)]:
if op in line:
l, r = line.split(op, 1)
lhs = expr(l)
return Assign(lhs, UOp(uop, dtypes.void, (lhs, expr(r))))
lhs, rhs = expr(l), expr(r)
# TODO: track the dtypes of the created variables here
return Assign(lhs, UOp(uop, dtypes.void, (lhs, rhs)))
if '=' in line and not any(line[:k] == p for k, p in [(3,'if '),(6,'elsif '),(4,'for ')]):
# Find leftmost assignment = (not ==, <=, >=, !=) for chained assignment support
eq = -1

View File

@@ -2,6 +2,7 @@
"""Benchmark comparing Python vs Rust RDNA3 emulators on real tinygrad kernels."""
import ctypes, time, os
from pathlib import Path
from tinygrad.helpers import getenv, Profiling
# Set AMD=1 before importing tinygrad
os.environ["AMD"] = "1"
@@ -125,6 +126,7 @@ def main():
import argparse
parser = argparse.ArgumentParser(description="Benchmark RDNA3 emulators")
parser.add_argument("--iterations", type=int, default=3, help="Number of iterations per benchmark")
parser.add_argument("--profile", action='store_true', help="Enable profiler")
args = parser.parse_args()
rust_remu = get_rust_remu()
@@ -159,7 +161,8 @@ def main():
buffers, args_arr, args_ptr, ranges = setup_buffers(buf_sizes, buf_data)
set_valid_mem_ranges(ranges)
py_time = benchmark_emulator("Python", python_run_asm, kernel, global_size, local_size, args_ptr, rsrc2, args.iterations)
with Profiling(enabled=args.profile):
py_time = benchmark_emulator("Python", python_run_asm, kernel, global_size, local_size, args_ptr, rsrc2, args.iterations)
rust_time = benchmark_emulator("Rust", rust_remu.run_asm, kernel, global_size, local_size, args_ptr, rsrc2, args.iterations) if rust_remu else None
if py_time:

View File

@@ -574,7 +574,14 @@ _DTYPE_ACCESSOR = {dtypes.uint8: 'u8', dtypes.int8: 'i8', dtypes.uint16: 'u16',
def _compile_pseudocode(pseudocode: str, mem_buf: UOp = MEM_BUF) -> tuple[UOp, list[tuple[str, DType]], dict[str, UOp], list[UOp]]:
ctx = Ctx(mem_buf=mem_buf)
for stmt in parse(pseudocode): _stmt(stmt, ctx)
try:
stmts = parse(pseudocode)
except AssertionError as e:
print("issue parsing")
print(pseudocode)
print(e)
raise
for stmt in stmts: _stmt(stmt, ctx)
return UOp(Ops.SINK, dtypes.void, tuple(u for _, u, _ in ctx.outputs) or ()), [(n, d) for n, _, d in ctx.outputs], INPUT_VARS, ctx.mem_stores
def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[str, UOp], mem_stores: list[UOp]):