diff --git a/extra/assembly/amd/pcode_parse.py b/extra/assembly/amd/pcode_parse.py index ea37eef7de..5f22058c7c 100644 --- a/extra/assembly/amd/pcode_parse.py +++ b/extra/assembly/amd/pcode_parse.py @@ -11,7 +11,8 @@ from tinygrad.helpers import DEBUG from tinygrad.dtype import INVERSE_DTYPES_DICT _QDTYPES: dict[str, DType] = { 'f64': dtypes.float64, 'f32': dtypes.float32, 'f16': dtypes.float16, 'bf16': dtypes.bfloat16, - 'fp8': DType.new(4, 8, "fp8", None), 'bf8': DType.new(4, 8, "bf8", None), 'fp6': DType.new(4, 6, "fp6", None), 'bf6': DType.new(4, 6, "bf6", None), + 'fp8': DType.new(4, 8, "fp8", None), 'bf8': DType.new(4, 8, "bf8", None), + 'fp6': DType.new(4, 6, "fp6", None), 'bf6': DType.new(4, 6, "bf6", None), 'fp4': DType.new(4, 4, "fp4", None), 'i4': DType.new(5, 4, "i4", None), 'u64': dtypes.uint64, 'u32': dtypes.uint32, 'u16': dtypes.uint16, 'u8': dtypes.uint8, 'i64': dtypes.int64, 'i32': dtypes.int32, 'i16': dtypes.int16, 'i8': dtypes.int8, @@ -20,8 +21,10 @@ _QDTYPES: dict[str, DType] = { 'b65': DType.new(6, 65, "b65", None), 'b64': dtypes.uint64, 'b32': dtypes.uint32, 'b23': DType.new(6, 23, "b23", None), 'b16': dtypes.uint16, 'b8': dtypes.uint8, 'b4': DType.new(6, 4, "b4", None), 'u1201': DType.new(6, 1201, "u1201", None), 'u65': DType.new(6, 65, "u65", None), 'u24': DType.new(6, 24, "u24", None), 'u23': DType.new(6, 23, "u23", None), 'u6': DType.new(6, 6, "u6", None), 'u4': DType.new(6, 4, "u4", None), + # TODO: why is there i1 and u1? 'u3': DType.new(6, 3, "u3", None), 'u1': DType.new(6, 1, "u1", None), - 'i65': DType.new(5, 65, "i65", None), 'i24': DType.new(5, 24, "i24", None), 'i1': DType.new(5, 1, "i1", None), + 'i65': DType.new(5, 65, "i65", None), 'i24': DType.new(5, 24, "i24", None), + 'i1': DType.new(5, 1, "i1", None), 'u': dtypes.uint32, 'i': dtypes.int32, 'f': dtypes.float32, } # Register custom dtypes for repr @@ -36,6 +39,7 @@ _BINOPS: dict[str, Ops] = { '<': Ops.CMPLT, '>': Ops.CMPLT, '<=': Ops.CMPLE, '>=': Ops.CMPLE, '||': Ops.OR, '&&': Ops.AND, } +# TODO: XOR and CMPEQ are binops. see if you can distinguish this based on types and use NEG for all _UNOPS: dict[str, Ops] = {'-': Ops.NEG, '~': Ops.XOR, '!': Ops.CMPEQ} # Statement types (control flow, not expressions) @@ -103,6 +107,7 @@ def expr(s: str) -> UOp: if s[0] == '{' and s[-1] == '}': return UOp(Ops.CAT, dtypes.void, tuple(expr(a) for a in _split(s[1:-1]))) # Typed cast: 32'U(expr) if m := re.match(r"^(\d+)'([IUFB])\(", s): + # TODO: is this cast or bitcast? I think it's also BITCAST, but understand better if (e := _match(s, m.end()-1, '(', ')')) == len(s)-1: return UOp(Ops.CAST, _QDTYPES[f"{m[2].lower()}{m[1]}"], (expr(s[m.end():e]),)) # Typed constant: 32'-5I if m := re.match(r"^(\d+)'(-?\d+)([IUFB])?$", s): @@ -115,7 +120,10 @@ def expr(s: str) -> UOp: if m := re.match(r"^([A-Za-z_]\w*)\(", s): if (e := _match(s, m.end()-1, '(', ')')) == len(s)-1: a = _split(s[m.end():e]) - return UOp(Ops.CUSTOM, dtypes.void, tuple(expr(x) for x in a) if a != [''] else (), arg=m[1]) + srcs = tuple(expr(x) for x in a) if a != [''] else () + output_dtype = dtypes.void + # TODO: set the output dtypes to reasonable things based on the input dtypes + return UOp(Ops.CUSTOM, output_dtype, srcs, arg=m[1]) # MEM[addr] -> CUSTOM('MEM', addr), MEM[addr].type -> BITCAST if s[:4] == 'MEM[' and (e := _match(s, 3, '[', ']')) != -1: r, b = s[e+1:], UOp(Ops.CUSTOM, dtypes.void, (expr(s[4:e]),), arg='MEM') @@ -129,7 +137,12 @@ def expr(s: str) -> UOp: elif s[i] == ')': d -= 1 elif s[i] == '[': b += 1 elif s[i] == ']': b -= 1 - elif s[i] == ':' and d == 0 and b == 0: return UOp(Ops.WHERE, dtypes.void, (expr(s[:q]), expr(s[q+1:i]), expr(s[i+1:]))) + elif s[i] == ':' and d == 0 and b == 0: + gate, lhs, rhs = expr(s[:q]), expr(s[q+1:i]), expr(s[i+1:]) + # TODO: enable this. dtypes.u1 and dtypes.i1 is probably fine too on gate + #assert gate.dtype == dtypes.bool, f"gate on where must be bool, got {gate.dtype}" + #assert lhs.dtype != dtypes.void or rhs.dtype != dtypes.void, "lhs/rhs can't both be void in WHERE" + return UOp(Ops.WHERE, dtypes.void, (gate, lhs, rhs)) # Binary ops for ops in [('||',),('&&',),('|',),('^',),('&',),('==','!=','<>'),('<=','>=','<','>'),('<<','>>'), ('+','-'),('*','/','%'),('**',)]: @@ -141,9 +154,17 @@ def expr(s: str) -> UOp: flipped = op in ('>', '>=') if flipped: lhs, rhs = rhs, lhs tag = 'flipped' if flipped else ('<>' if op == '<>' else None) - return UOp(_BINOPS[op], dtypes.void, (lhs, rhs), tag=tag) + uop_op = _BINOPS[op] + # TODO: enable this + #assert lhs.dtype != dtypes.void or rhs.dtype != dtypes.void, f"lhs/rhs can't both be void in {uop_op}" + # TODO: check that dtypes match + output_dtype = lhs.dtype if lhs.dtype != dtypes.void else rhs.dtype + if uop_op in {Ops.CMPNE, Ops.CMPEQ, Ops.CMPLE, Ops.CMPLT}: output_dtype = dtypes.bool + return UOp(uop_op, output_dtype, (lhs, rhs), tag=tag) # Unary ops - if s[0] in '-~!' and len(s) > 1 and (s[0] != '!' or s[1] != '='): return UOp(_UNOPS[s[0]], dtypes.void, (expr(s[1:]),)) + if s[0] in '-~!' and len(s) > 1 and (s[0] != '!' or s[1] != '='): + src = expr(s[1:]) + return UOp(_UNOPS[s[0]], src.dtype, (src,)) # Slice/Index -> CUSTOMI if '[' in s and s[-1] == ']': d = 0 @@ -154,7 +175,9 @@ def expr(s: str) -> UOp: b, n = s[:i], s[i+1:-1] if '+:' in n: # Verilog [start +: width] st, w = expr(n.split('+:', 1)[0]), expr(n.split('+:', 1)[1]) + # TODO: correct dtypes here hi = UOp(Ops.SUB, dtypes.void, (UOp(Ops.ADD, dtypes.void, (st, w)), UOp(Ops.CONST, dtypes.int32, arg=1))) + # TODO: use SHRINK instead of CUSTOMI, it's more sensible even if it doesn't perfectly match spec return UOp(Ops.CUSTOMI, dtypes.void, (expr(b), hi, st)) if ':' in n and '?' not in n: d = 0 @@ -172,6 +195,7 @@ def expr(s: str) -> UOp: if s[:5] == 'eval ': return UOp(Ops.DEFINE_VAR, dtypes.void, arg=(s, None, None)) if re.match(r'^[A-Za-z_][\w.]*$', s): return UOp(Ops.DEFINE_VAR, dtypes.void, arg=(s, None, None)) # Numeric literal + # TODO: hex constants are unsigned even without U try: if s[:2].lower() == '0x': m = re.match(r'0[xX]([0-9a-fA-F]+)([UuLl]*)$', s) @@ -191,6 +215,7 @@ def expr(s: str) -> UOp: raise ValueError(f"Cannot parse expression: {s}") def stmt(line: str) -> Stmt|None: + # TODO: track dtypes of variables here on inputs. SCC is dtypes.bool by default. ADDR is dtypes.uint64 line = line.split('//')[0].strip().rstrip(';').rstrip('.') if not line: return None if line == 'break': return Break() @@ -203,13 +228,15 @@ def stmt(line: str) -> Stmt|None: t = t.split('[')[0] # strip array suffix like [64] if m := re.match(r"^(\d+)'([IUFB])$", t): dt = _QDTYPES[f"{m[2].lower()}{m[1]}"] + # TODO: track the dtypes of the created variables here return Declare(n.strip(), dt.vec(vec_count) if vec_count > 1 else dt) return None # unsupported declare type for op, uop in [('+=', Ops.ADD), ('-=', Ops.SUB), ('|=', Ops.OR), ('&=', Ops.AND), ('^=', Ops.XOR), ('<<=', Ops.SHL), ('>>=', Ops.SHR)]: if op in line: l, r = line.split(op, 1) - lhs = expr(l) - return Assign(lhs, UOp(uop, dtypes.void, (lhs, expr(r)))) + lhs, rhs = expr(l), expr(r) + # TODO: track the dtypes of the created variables here + return Assign(lhs, UOp(uop, dtypes.void, (lhs, rhs))) if '=' in line and not any(line[:k] == p for k, p in [(3,'if '),(6,'elsif '),(4,'for ')]): # Find leftmost assignment = (not ==, <=, >=, !=) for chained assignment support eq = -1 diff --git a/extra/assembly/amd/test/bench_emu.py b/extra/assembly/amd/test/bench_emu.py index 1a8871d133..55e73655c1 100644 --- a/extra/assembly/amd/test/bench_emu.py +++ b/extra/assembly/amd/test/bench_emu.py @@ -2,6 +2,7 @@ """Benchmark comparing Python vs Rust RDNA3 emulators on real tinygrad kernels.""" import ctypes, time, os from pathlib import Path +from tinygrad.helpers import getenv, Profiling # Set AMD=1 before importing tinygrad os.environ["AMD"] = "1" @@ -125,6 +126,7 @@ def main(): import argparse parser = argparse.ArgumentParser(description="Benchmark RDNA3 emulators") parser.add_argument("--iterations", type=int, default=3, help="Number of iterations per benchmark") + parser.add_argument("--profile", action='store_true', help="Enable profiler") args = parser.parse_args() rust_remu = get_rust_remu() @@ -159,7 +161,8 @@ def main(): buffers, args_arr, args_ptr, ranges = setup_buffers(buf_sizes, buf_data) set_valid_mem_ranges(ranges) - py_time = benchmark_emulator("Python", python_run_asm, kernel, global_size, local_size, args_ptr, rsrc2, args.iterations) + with Profiling(enabled=args.profile): + py_time = benchmark_emulator("Python", python_run_asm, kernel, global_size, local_size, args_ptr, rsrc2, args.iterations) rust_time = benchmark_emulator("Rust", rust_remu.run_asm, kernel, global_size, local_size, args_ptr, rsrc2, args.iterations) if rust_remu else None if py_time: diff --git a/extra/assembly/amd/ucode.py b/extra/assembly/amd/ucode.py index a9a8f8e04d..f37498396e 100644 --- a/extra/assembly/amd/ucode.py +++ b/extra/assembly/amd/ucode.py @@ -574,7 +574,14 @@ _DTYPE_ACCESSOR = {dtypes.uint8: 'u8', dtypes.int8: 'i8', dtypes.uint16: 'u16', def _compile_pseudocode(pseudocode: str, mem_buf: UOp = MEM_BUF) -> tuple[UOp, list[tuple[str, DType]], dict[str, UOp], list[UOp]]: ctx = Ctx(mem_buf=mem_buf) - for stmt in parse(pseudocode): _stmt(stmt, ctx) + try: + stmts = parse(pseudocode) + except AssertionError as e: + print("issue parsing") + print(pseudocode) + print(e) + raise + for stmt in stmts: _stmt(stmt, ctx) return UOp(Ops.SINK, dtypes.void, tuple(u for _, u, _ in ctx.outputs) or ()), [(n, d) for n, _, d in ctx.outputs], INPUT_VARS, ctx.mem_stores def _make_fn(sink: UOp, output_info: list[tuple[str, DType]], input_vars: dict[str, UOp], mem_stores: list[UOp]):