mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
a bunch of todos for my boy claude
This commit is contained in:
@@ -11,7 +11,8 @@ from tinygrad.helpers import DEBUG
|
||||
from tinygrad.dtype import INVERSE_DTYPES_DICT
|
||||
_QDTYPES: dict[str, DType] = {
|
||||
'f64': dtypes.float64, 'f32': dtypes.float32, 'f16': dtypes.float16, 'bf16': dtypes.bfloat16,
|
||||
'fp8': DType.new(4, 8, "fp8", None), 'bf8': DType.new(4, 8, "bf8", None), 'fp6': DType.new(4, 6, "fp6", None), 'bf6': DType.new(4, 6, "bf6", None),
|
||||
'fp8': DType.new(4, 8, "fp8", None), 'bf8': DType.new(4, 8, "bf8", None),
|
||||
'fp6': DType.new(4, 6, "fp6", None), 'bf6': DType.new(4, 6, "bf6", None),
|
||||
'fp4': DType.new(4, 4, "fp4", None), 'i4': DType.new(5, 4, "i4", None),
|
||||
'u64': dtypes.uint64, 'u32': dtypes.uint32, 'u16': dtypes.uint16, 'u8': dtypes.uint8,
|
||||
'i64': dtypes.int64, 'i32': dtypes.int32, 'i16': dtypes.int16, 'i8': dtypes.int8,
|
||||
@@ -20,8 +21,10 @@ _QDTYPES: dict[str, DType] = {
|
||||
'b65': DType.new(6, 65, "b65", None), 'b64': dtypes.uint64, 'b32': dtypes.uint32, 'b23': DType.new(6, 23, "b23", None), 'b16': dtypes.uint16, 'b8': dtypes.uint8, 'b4': DType.new(6, 4, "b4", None),
|
||||
'u1201': DType.new(6, 1201, "u1201", None), 'u65': DType.new(6, 65, "u65", None), 'u24': DType.new(6, 24, "u24", None), 'u23': DType.new(6, 23, "u23", None),
|
||||
'u6': DType.new(6, 6, "u6", None), 'u4': DType.new(6, 4, "u4", None),
|
||||
# TODO: why is there i1 and u1?
|
||||
'u3': DType.new(6, 3, "u3", None), 'u1': DType.new(6, 1, "u1", None),
|
||||
'i65': DType.new(5, 65, "i65", None), 'i24': DType.new(5, 24, "i24", None), 'i1': DType.new(5, 1, "i1", None),
|
||||
'i65': DType.new(5, 65, "i65", None), 'i24': DType.new(5, 24, "i24", None),
|
||||
'i1': DType.new(5, 1, "i1", None),
|
||||
'u': dtypes.uint32, 'i': dtypes.int32, 'f': dtypes.float32,
|
||||
}
|
||||
# Register custom dtypes for repr
|
||||
@@ -36,6 +39,7 @@ _BINOPS: dict[str, Ops] = {
|
||||
'<': Ops.CMPLT, '>': Ops.CMPLT, '<=': Ops.CMPLE, '>=': Ops.CMPLE,
|
||||
'||': Ops.OR, '&&': Ops.AND,
|
||||
}
|
||||
# TODO: XOR and CMPEQ are binops. See if you can distinguish these based on types and use NEG for all of them
|
||||
_UNOPS: dict[str, Ops] = {'-': Ops.NEG, '~': Ops.XOR, '!': Ops.CMPEQ}
|
||||
|
||||
# Statement types (control flow, not expressions)
|
||||
@@ -103,6 +107,7 @@ def expr(s: str) -> UOp:
|
||||
if s[0] == '{' and s[-1] == '}': return UOp(Ops.CAT, dtypes.void, tuple(expr(a) for a in _split(s[1:-1])))
|
||||
# Typed cast: 32'U(expr)
|
||||
if m := re.match(r"^(\d+)'([IUFB])\(", s):
|
||||
# TODO: is this a cast or a bitcast? I think it's also BITCAST, but this needs to be understood better
|
||||
if (e := _match(s, m.end()-1, '(', ')')) == len(s)-1: return UOp(Ops.CAST, _QDTYPES[f"{m[2].lower()}{m[1]}"], (expr(s[m.end():e]),))
|
||||
# Typed constant: 32'-5I
|
||||
if m := re.match(r"^(\d+)'(-?\d+)([IUFB])?$", s):
|
||||
@@ -115,7 +120,10 @@ def expr(s: str) -> UOp:
|
||||
if m := re.match(r"^([A-Za-z_]\w*)\(", s):
|
||||
if (e := _match(s, m.end()-1, '(', ')')) == len(s)-1:
|
||||
a = _split(s[m.end():e])
|
||||
return UOp(Ops.CUSTOM, dtypes.void, tuple(expr(x) for x in a) if a != [''] else (), arg=m[1])
|
||||
srcs = tuple(expr(x) for x in a) if a != [''] else ()
|
||||
output_dtype = dtypes.void
|
||||
# TODO: set the output dtypes to reasonable things based on the input dtypes
|
||||
return UOp(Ops.CUSTOM, output_dtype, srcs, arg=m[1])
|
||||
# MEM[addr] -> CUSTOM('MEM', addr), MEM[addr].type -> BITCAST
|
||||
if s[:4] == 'MEM[' and (e := _match(s, 3, '[', ']')) != -1:
|
||||
r, b = s[e+1:], UOp(Ops.CUSTOM, dtypes.void, (expr(s[4:e]),), arg='MEM')
|
||||
@@ -129,7 +137,12 @@ def expr(s: str) -> UOp:
|
||||
elif s[i] == ')': d -= 1
|
||||
elif s[i] == '[': b += 1
|
||||
elif s[i] == ']': b -= 1
|
||||
elif s[i] == ':' and d == 0 and b == 0: return UOp(Ops.WHERE, dtypes.void, (expr(s[:q]), expr(s[q+1:i]), expr(s[i+1:])))
|
||||
elif s[i] == ':' and d == 0 and b == 0:
|
||||
gate, lhs, rhs = expr(s[:q]), expr(s[q+1:i]), expr(s[i+1:])
|
||||
# TODO: enable this. dtypes.u1 and dtypes.i1 are probably fine for the gate too
|
||||
#assert gate.dtype == dtypes.bool, f"gate on where must be bool, got {gate.dtype}"
|
||||
#assert lhs.dtype != dtypes.void or rhs.dtype != dtypes.void, "lhs/rhs can't both be void in WHERE"
|
||||
return UOp(Ops.WHERE, dtypes.void, (gate, lhs, rhs))
|
||||
# Binary ops
|
||||
for ops in [('||',),('&&',),('|',),('^',),('&',),('==','!=','<>'),('<=','>=','<','>'),('<<','>>'),
|
||||
('+','-'),('*','/','%'),('**',)]:
|
||||
@@ -141,9 +154,17 @@ def expr(s: str) -> UOp:
|
||||
flipped = op in ('>', '>=')
|
||||
if flipped: lhs, rhs = rhs, lhs
|
||||
tag = 'flipped' if flipped else ('<>' if op == '<>' else None)
|
||||
return UOp(_BINOPS[op], dtypes.void, (lhs, rhs), tag=tag)
|
||||
uop_op = _BINOPS[op]
|
||||
# TODO: enable this
|
||||
#assert lhs.dtype != dtypes.void or rhs.dtype != dtypes.void, f"lhs/rhs can't both be void in {uop_op}"
|
||||
# TODO: check that dtypes match
|
||||
output_dtype = lhs.dtype if lhs.dtype != dtypes.void else rhs.dtype
|
||||
if uop_op in {Ops.CMPNE, Ops.CMPEQ, Ops.CMPLE, Ops.CMPLT}: output_dtype = dtypes.bool
|
||||
return UOp(uop_op, output_dtype, (lhs, rhs), tag=tag)
|
||||
# Unary ops
|
||||
if s[0] in '-~!' and len(s) > 1 and (s[0] != '!' or s[1] != '='): return UOp(_UNOPS[s[0]], dtypes.void, (expr(s[1:]),))
|
||||
if s[0] in '-~!' and len(s) > 1 and (s[0] != '!' or s[1] != '='):
|
||||
src = expr(s[1:])
|
||||
return UOp(_UNOPS[s[0]], src.dtype, (src,))
|
||||
# Slice/Index -> CUSTOMI
|
||||
if '[' in s and s[-1] == ']':
|
||||
d = 0
|
||||
@@ -154,7 +175,9 @@ def expr(s: str) -> UOp:
|
||||
b, n = s[:i], s[i+1:-1]
|
||||
if '+:' in n: # Verilog [start +: width]
|
||||
st, w = expr(n.split('+:', 1)[0]), expr(n.split('+:', 1)[1])
|
||||
# TODO: correct dtypes here
|
||||
hi = UOp(Ops.SUB, dtypes.void, (UOp(Ops.ADD, dtypes.void, (st, w)), UOp(Ops.CONST, dtypes.int32, arg=1)))
|
||||
# TODO: use SHRINK instead of CUSTOMI, it's more sensible even if it doesn't perfectly match spec
|
||||
return UOp(Ops.CUSTOMI, dtypes.void, (expr(b), hi, st))
|
||||
if ':' in n and '?' not in n:
|
||||
d = 0
|
||||
@@ -172,6 +195,7 @@ def expr(s: str) -> UOp:
|
||||
if s[:5] == 'eval ': return UOp(Ops.DEFINE_VAR, dtypes.void, arg=(s, None, None))
|
||||
if re.match(r'^[A-Za-z_][\w.]*$', s): return UOp(Ops.DEFINE_VAR, dtypes.void, arg=(s, None, None))
|
||||
# Numeric literal
|
||||
# TODO: hex constants are unsigned even without U
|
||||
try:
|
||||
if s[:2].lower() == '0x':
|
||||
m = re.match(r'0[xX]([0-9a-fA-F]+)([UuLl]*)$', s)
|
||||
@@ -191,6 +215,7 @@ def expr(s: str) -> UOp:
|
||||
raise ValueError(f"Cannot parse expression: {s}")
|
||||
|
||||
def stmt(line: str) -> Stmt|None:
|
||||
# TODO: track dtypes of variables here on inputs. SCC is dtypes.bool by default. ADDR is dtypes.uint64
|
||||
line = line.split('//')[0].strip().rstrip(';').rstrip('.')
|
||||
if not line: return None
|
||||
if line == 'break': return Break()
|
||||
@@ -203,13 +228,15 @@ def stmt(line: str) -> Stmt|None:
|
||||
t = t.split('[')[0] # strip array suffix like [64]
|
||||
if m := re.match(r"^(\d+)'([IUFB])$", t):
|
||||
dt = _QDTYPES[f"{m[2].lower()}{m[1]}"]
|
||||
# TODO: track the dtypes of the created variables here
|
||||
return Declare(n.strip(), dt.vec(vec_count) if vec_count > 1 else dt)
|
||||
return None # unsupported declare type
|
||||
for op, uop in [('+=', Ops.ADD), ('-=', Ops.SUB), ('|=', Ops.OR), ('&=', Ops.AND), ('^=', Ops.XOR), ('<<=', Ops.SHL), ('>>=', Ops.SHR)]:
|
||||
if op in line:
|
||||
l, r = line.split(op, 1)
|
||||
lhs = expr(l)
|
||||
return Assign(lhs, UOp(uop, dtypes.void, (lhs, expr(r))))
|
||||
lhs, rhs = expr(l), expr(r)
|
||||
# TODO: track the dtypes of the created variables here
|
||||
return Assign(lhs, UOp(uop, dtypes.void, (lhs, rhs)))
|
||||
if '=' in line and not any(line[:k] == p for k, p in [(3,'if '),(6,'elsif '),(4,'for ')]):
|
||||
# Find leftmost assignment = (not ==, <=, >=, !=) for chained assignment support
|
||||
eq = -1
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
"""Benchmark comparing Python vs Rust RDNA3 emulators on real tinygrad kernels."""
|
||||
import ctypes, time, os
|
||||
from pathlib import Path
|
||||
from tinygrad.helpers import getenv, Profiling
|
||||
|
||||
# Set AMD=1 before importing tinygrad
|
||||
os.environ["AMD"] = "1"
|
||||
@@ -125,6 +126,7 @@ def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Benchmark RDNA3 emulators")
|
||||
parser.add_argument("--iterations", type=int, default=3, help="Number of iterations per benchmark")
|
||||
parser.add_argument("--profile", action='store_true', help="Enable profiler")
|
||||
args = parser.parse_args()
|
||||
|
||||
rust_remu = get_rust_remu()
|
||||
@@ -159,7 +161,8 @@ def main():
|
||||
buffers, args_arr, args_ptr, ranges = setup_buffers(buf_sizes, buf_data)
|
||||
set_valid_mem_ranges(ranges)
|
||||
|
||||
py_time = benchmark_emulator("Python", python_run_asm, kernel, global_size, local_size, args_ptr, rsrc2, args.iterations)
|
||||
with Profiling(enabled=args.profile):
|
||||
py_time = benchmark_emulator("Python", python_run_asm, kernel, global_size, local_size, args_ptr, rsrc2, args.iterations)
|
||||
rust_time = benchmark_emulator("Rust", rust_remu.run_asm, kernel, global_size, local_size, args_ptr, rsrc2, args.iterations) if rust_remu else None
|
||||
|
||||
if py_time:
|
||||
|
||||
@@ -574,7 +574,14 @@ _DTYPE_ACCESSOR = {dtypes.uint8: 'u8', dtypes.int8: 'i8', dtypes.uint16: 'u16',
|
||||
|
||||
def _compile_pseudocode(pseudocode: str, mem_buf: UOp = MEM_BUF) -> tuple[UOp, list[tuple[str, DType]], dict[str, UOp], list[UOp]]:
|
||||
ctx = Ctx(mem_buf=mem_buf)
|
||||
for stmt in parse(pseudocode): _stmt(stmt, ctx)
|
||||
try:
|
||||
stmts = parse(pseudocode)
|
||||
except AssertionError as e:
|
||||
print("issue parsing")
|
||||
print(pseudocode)
|
||||
print(e)
|
||||
raise
|
||||
for stmt in stmts: _stmt(stmt, ctx)
|
||||
return UOp(Ops.SINK, dtypes.void, tuple(u for _, u, _ in ctx.outputs) or ()), [(n, d) for n, _, d in ctx.outputs], INPUT_VARS, ctx.mem_stores
|
||||
|
||||
def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[str, UOp], mem_stores: list[UOp]):
|
||||
|
||||
Reference in New Issue
Block a user