From e85f6594b35c0ecda4c4cdbfcb8e64aceed2272e Mon Sep 17 00:00:00 2001 From: George Hotz Date: Sat, 3 Jan 2026 13:22:22 -0800 Subject: [PATCH] simpler and faster --- extra/assembly/amd/dsl.py | 8 +- extra/assembly/amd/pcode.py | 12 +-- extra/assembly/amd/test/bench_emu.py | 143 +++++++++++++++++++++++++++ 3 files changed, 151 insertions(+), 12 deletions(-) diff --git a/extra/assembly/amd/dsl.py b/extra/assembly/amd/dsl.py index 0e34374fd8..abb0ce8cbd 100644 --- a/extra/assembly/amd/dsl.py +++ b/extra/assembly/amd/dsl.py @@ -14,17 +14,13 @@ _struct_f, _struct_I = struct.Struct(" 0 else 0xff800000 try: - bits = _struct_I.unpack(_struct_f.pack(f))[0] + bits = _struct_I.unpack(_struct_f.pack(float(f)))[0] # RDNA3 default mode: flush f32 denormals to zero (FTZ) if (bits & 0x7f800000) == 0 and (bits & 0x007fffff) != 0: return 0x80000000 if bits & 0x80000000 else 0 return bits diff --git a/extra/assembly/amd/pcode.py b/extra/assembly/amd/pcode.py index db6bde512e..5375052a84 100644 --- a/extra/assembly/amd/pcode.py +++ b/extra/assembly/amd/pcode.py @@ -381,10 +381,10 @@ class TypedView: def __index__(self): return int(self) def __trunc__(self): return int(float(self)) if self._float else int(self) def __float__(self): - if self._float: - if self._bf16: return _bf16(self._val) # bf16 uses different conversion - return _f16(self._val) if self._bits == 16 else _f32(self._val) if self._bits == 32 else _f64(self._val) - return float(int(self)) + if not self._float: return float(self._val) + if self._bf16: return _bf16(self._val) + if self._bits == 32: return _f32(self._val) + return _f16(self._val) if self._bits == 16 else _f64(self._val) # Arithmetic - floats use float(), ints use int() def __add__(s, o): return float(s) + float(o) if s._float else int(s) + int(o) @@ -447,11 +447,11 @@ class Reg: u64 = property(lambda s: TypedView(s, 64), lambda s, v: setattr(s, '_val', int(v) & MASK64)) i64 = property(lambda s: TypedView(s, 64, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK64)) b64 = property(lambda s: TypedView(s, 64), lambda s, v: setattr(s, '_val', int(v) & MASK64)) - f64 = property(lambda s: TypedView(s, 64, is_float=True), lambda s, v: setattr(s, '_val', v if isinstance(v, int) else _i64(float(v)))) + f64 = property(lambda s: TypedView(s, 64, is_float=True), lambda s, v: setattr(s, '_val', v if isinstance(v, int) else _i64(v))) u32 = property(lambda s: TypedView(s, 32), lambda s, v: setattr(s, '_val', int(v) & MASK32)) i32 = property(lambda s: TypedView(s, 32, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK32)) b32 = property(lambda s: TypedView(s, 32), lambda s, v: setattr(s, '_val', int(v) & MASK32)) - f32 = property(lambda s: TypedView(s, 32, is_float=True), lambda s, v: setattr(s, '_val', _i32(float(v)))) + f32 = property(lambda s: TypedView(s, 32, is_float=True), lambda s, v: setattr(s, '_val', _i32(v))) u24 = property(lambda s: TypedView(s, 24)) i24 = property(lambda s: TypedView(s, 24, signed=True)) u16 = property(lambda s: TypedView(s, 16), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff))) diff --git a/extra/assembly/amd/test/bench_emu.py b/extra/assembly/amd/test/bench_emu.py index 1a8871d133..9cee2a378a 100644 --- a/extra/assembly/amd/test/bench_emu.py +++ b/extra/assembly/amd/test/bench_emu.py @@ -7,6 +7,9 @@ from pathlib import Path os.environ["AMD"] = "1" from extra.assembly.amd.emu import run_asm as python_run_asm, set_valid_mem_ranges, decode_program +from extra.assembly.amd.autogen.rdna3.gen_pcode import _VOP2Op_V_ADD_F32, _VOP2Op_V_MUL_F32, _VOP2Op_V_FMAC_F32, _VOP2Op_V_LSHLREV_B32, _VOP2Op_V_AND_B32 +from extra.assembly.amd.pcode import Reg +from extra.assembly.amd.dsl import _i32 REMU_PATH = Path(__file__).parents[3] / "remu/target/release/libremu.so" if not REMU_PATH.exists(): @@ -121,12 +124,152 @@ def get_tinygrad_kernel(op_name: str) -> tuple[bytes, tuple, tuple, list[int], d TINYGRAD_TESTS = ["add", "mul", "reduce_sum", "softmax", "exp", "gelu", "matmul_small"] +# ═══════════════════════════════════════════════════════════════════════════════ +# PCODE MICROBENCHMARKS - test individual pcode function performance +# ═══════════════════════════════════════════════════════════════════════════════ + +def microbench_pcode(iterations: int = 100000): + """Microbenchmark individual pcode functions to identify Reg/TypedView overhead.""" + print("\n" + "=" * 90) + print("PCODE MICROBENCHMARKS") + print("=" * 90) + + # Test values (as raw ints, like the emulator passes them) + f32_1 = _i32(1.5) + f32_2 = _i32(2.5) + f32_3 = _i32(0.5) + int_5 = 5 + int_mask = 0xff00ff00 + + tests = [ + ("V_ADD_F32", lambda: _VOP2Op_V_ADD_F32(f32_1, f32_2, 0, 0, 0, 0, 0, 0xffffffff, 0, None)), + ("V_MUL_F32", lambda: _VOP2Op_V_MUL_F32(f32_1, f32_2, 0, 0, 0, 0, 0, 0xffffffff, 0, None)), + ("V_FMAC_F32", lambda: _VOP2Op_V_FMAC_F32(f32_1, f32_2, f32_3, f32_3, 0, 0, 0, 0xffffffff, 0, None)), + ("V_LSHLREV_B32", lambda: _VOP2Op_V_LSHLREV_B32(int_5, int_mask, 0, 0, 0, 0, 0, 0xffffffff, 0, None)), + ("V_AND_B32", lambda: _VOP2Op_V_AND_B32(int_mask, 0x12345678, 0, 0, 0, 0, 0, 0xffffffff, 0, None)), + ] + + # Baseline: measure overhead of just calling a lambda + def baseline_fn(): return {'D0': 42} + start = time.perf_counter() + for _ in range(iterations): baseline_fn() + baseline_time = time.perf_counter() - start + print(f"\n{'Baseline (empty fn)':<25} {baseline_time*1e6/iterations:8.3f} µs/call") + print("-" * 50) + + for name, fn in tests: + # Warmup + for _ in range(1000): fn() + # Timed + start = time.perf_counter() + for _ in range(iterations): fn() + elapsed = time.perf_counter() - start + us_per_call = elapsed * 1e6 / iterations + overhead = us_per_call - (baseline_time * 1e6 / iterations) + print(f"{name:<25} {us_per_call:8.3f} µs/call (overhead: {overhead:6.3f} µs)") + + # Measure Reg creation overhead separately + print("\n" + "-" * 50) + print("Component breakdown:") + + # Just Reg creation + start = time.perf_counter() + for _ in range(iterations): Reg(f32_1); Reg(f32_2); Reg(0) + reg_time = (time.perf_counter() - start) * 1e6 / iterations + print(f" {'3x Reg() creation':<23} {reg_time:8.3f} µs") + + # Reg + property access (no arithmetic) + start = time.perf_counter() + for _ in range(iterations): + r = Reg(f32_1) + _ = r.f32 + prop_time = (time.perf_counter() - start) * 1e6 / iterations + print(f" {'Reg() + .f32 access':<23} {prop_time:8.3f} µs") + + # Reg + TypedView arithmetic + start = time.perf_counter() + for _ in range(iterations): + r1, r2 = Reg(f32_1), Reg(f32_2) + _ = r1.f32 + r2.f32 + arith_time = (time.perf_counter() - start) * 1e6 / iterations + print(f" {'2x Reg + .f32 + add':<23} {arith_time:8.3f} µs") + + # Full pcode pattern: Reg creation + property + arithmetic + property setter + _val access + start = time.perf_counter() + for _ in range(iterations): + S0, S1, D0 = Reg(f32_1), Reg(f32_2), Reg(0) + D0.f32 = S0.f32 + S1.f32 + _ = D0._val + full_time = (time.perf_counter() - start) * 1e6 / iterations + print(f" {'Full pcode pattern':<23} {full_time:8.3f} µs") + + # Dict creation overhead + start = time.perf_counter() + for _ in range(iterations): _ = {'D0': 42} + dict_time = (time.perf_counter() - start) * 1e6 / iterations + print(f" {'Dict creation':<23} {dict_time:8.3f} µs") + + # TypedView.__float__ overhead + r = Reg(f32_1) + tv = r.f32 # get TypedView once + start = time.perf_counter() + for _ in range(iterations): float(tv) + float_time = (time.perf_counter() - start) * 1e6 / iterations + print(f" {'TypedView.__float__':<23} {float_time:8.3f} µs") + + # Just _f32 conversion + from extra.assembly.amd.dsl import _f32 + start = time.perf_counter() + for _ in range(iterations): _f32(f32_1) + f32_time = (time.perf_counter() - start) * 1e6 / iterations + print(f" {'_f32() conversion':<23} {f32_time:8.3f} µs") + + # TypedView._val property + start = time.perf_counter() + for _ in range(iterations): tv._val + val_time = (time.perf_counter() - start) * 1e6 / iterations + print(f" {'TypedView._val':<23} {val_time:8.3f} µs") + + # TypedView.__add__ (this calls __float__ twice + Python float add) + tv2 = Reg(f32_2).f32 + start = time.perf_counter() + for _ in range(iterations): tv + tv2 + add_time = (time.perf_counter() - start) * 1e6 / iterations + print(f" {'TypedView + TypedView':<23} {add_time:8.3f} µs") + + # Python float add baseline + pf1, pf2 = 1.5, 2.5 + start = time.perf_counter() + for _ in range(iterations): pf1 + pf2 + py_add_time = (time.perf_counter() - start) * 1e6 / iterations + print(f" {'Python float + float':<23} {py_add_time:8.3f} µs") + + # Setter: D0.f32 = result + d0 = Reg(0) + result = 4.0 + start = time.perf_counter() + for _ in range(iterations): d0.f32 = result + setter_time = (time.perf_counter() - start) * 1e6 / iterations + print(f" {'Reg.f32 = float':<23} {setter_time:8.3f} µs") + + # _i32 conversion alone + from extra.assembly.amd.dsl import _i32 as dsl_i32 + start = time.perf_counter() + for _ in range(iterations): dsl_i32(4.0) + i32_time = (time.perf_counter() - start) * 1e6 / iterations + print(f" {'_i32() conversion':<23} {i32_time:8.3f} µs") + def main(): import argparse parser = argparse.ArgumentParser(description="Benchmark RDNA3 emulators") parser.add_argument("--iterations", type=int, default=3, help="Number of iterations per benchmark") + parser.add_argument("--ubench", action="store_true", help="Run pcode microbenchmarks only") args = parser.parse_args() + if args.ubench: + microbench_pcode() + return + rust_remu = get_rust_remu() if rust_remu is None: print("Rust libremu not found. Build with: cargo build --release --manifest-path extra/remu/Cargo.toml")