mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-06 21:53:53 -05:00
simpler and faster
This commit is contained in:
@@ -14,17 +14,13 @@ _struct_f, _struct_I = struct.Struct("<f"), struct.Struct("<I")
|
||||
# Pre-compiled little-endian codecs, created once so conversions avoid re-parsing the format string:
# half float / unsigned 16-bit, then double / unsigned 64-bit.
_struct_e = struct.Struct("<e")
_struct_H = struct.Struct("<H")
_struct_d = struct.Struct("<d")
_struct_Q = struct.Struct("<Q")
|
||||
def _f32(i):
  """Reinterpret the low 32 bits of the integer ``i`` as a single-precision float.

  ``i`` is masked with ``MASK32`` before packing, so wider values are accepted.
  The bits are packed as a little-endian unsigned 32-bit word and unpacked as a
  little-endian binary32; the result is returned as a Python float. Denormal
  bit patterns are decoded as-is (no flush-to-zero on this read path).
  """
  return _struct_f.unpack(_struct_I.pack(i & MASK32))[0]
|
||||
def _i32(f):
|
||||
if isinstance(f, int): f = float(f)
|
||||
if math.isnan(f): return 0xffc00000 if math.copysign(1.0, f) < 0 else 0x7fc00000
|
||||
if math.isinf(f): return 0x7f800000 if f > 0 else 0xff800000
|
||||
try:
|
||||
bits = _struct_I.unpack(_struct_f.pack(f))[0]
|
||||
bits = _struct_I.unpack(_struct_f.pack(float(f)))[0]
|
||||
# RDNA3 default mode: flush f32 denormals to zero (FTZ)
|
||||
if (bits & 0x7f800000) == 0 and (bits & 0x007fffff) != 0: return 0x80000000 if bits & 0x80000000 else 0
|
||||
return bits
|
||||
|
||||
@@ -381,10 +381,10 @@ class TypedView:
|
||||
def __index__(self):
  """Return ``int(self)`` so the view can be used where Python requires an exact integer."""
  as_int = int(self)
  return as_int
|
||||
def __trunc__(self):
  """Truncate to an int: float views convert via ``float(self)`` first, other views use ``int(self)``."""
  if self._float:
    return int(float(self))
  return int(self)
|
||||
def __float__(self):
  """Return the view's value as a Python float.

  Non-float views convert ``self._val`` numerically. Float views hand
  ``self._val`` to the decoder matching the view's format: ``_bf16`` when the
  bf16 flag is set, otherwise ``_f32`` / ``_f16`` / ``_f64`` chosen by ``self._bits``.
  """
  if not self._float: return float(self._val)
  if self._bf16: return _bf16(self._val)  # bf16 uses its own conversion, so it is tested before the width
  if self._bits == 32: return _f32(self._val)
  return _f16(self._val) if self._bits == 16 else _f64(self._val)
|
||||
|
||||
# Arithmetic - floats use float(), ints use int()
|
||||
def __add__(s, o):
  """Add ``o`` to this view: float views add both operands as floats, other views add them as ints."""
  if s._float:
    return float(s) + float(o)
  return int(s) + int(o)
|
||||
@@ -447,11 +447,11 @@ class Reg:
|
||||
# Typed views over the register's raw `_val`. Each getter wraps the register in a TypedView of the
# given width; each setter (where present) writes back into `_val`.

# 64-bit views: integer setters keep only the low 64 bits.
u64 = property(lambda s: TypedView(s, 64), lambda s, v: setattr(s, '_val', int(v) & MASK64))
i64 = property(lambda s: TypedView(s, 64, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK64))
b64 = property(lambda s: TypedView(s, 64), lambda s, v: setattr(s, '_val', int(v) & MASK64))
# f64 setter: an int is stored unchanged (not masked, not converted); any other value is encoded by _i64.
f64 = property(lambda s: TypedView(s, 64, is_float=True), lambda s, v: setattr(s, '_val', v if isinstance(v, int) else _i64(v)))

# 32-bit views: integer setters replace `_val` with the low 32 bits of the new value.
u32 = property(lambda s: TypedView(s, 32), lambda s, v: setattr(s, '_val', int(v) & MASK32))
i32 = property(lambda s: TypedView(s, 32, signed=True), lambda s, v: setattr(s, '_val', int(v) & MASK32))
b32 = property(lambda s: TypedView(s, 32), lambda s, v: setattr(s, '_val', int(v) & MASK32))
# f32 setter: every value is encoded by _i32, which performs the float conversion itself.
f32 = property(lambda s: TypedView(s, 32, is_float=True), lambda s, v: setattr(s, '_val', _i32(v)))

# 24-bit views are read-only (no setter is defined).
u24 = property(lambda s: TypedView(s, 24))
i24 = property(lambda s: TypedView(s, 24, signed=True))

# 16-bit setter: replaces bits 0-15 and keeps bits 16-31 of the current `_val`.
u16 = property(lambda s: TypedView(s, 16), lambda s, v: setattr(s, '_val', (s._val & 0xffff0000) | (int(v) & 0xffff)))
|
||||
|
||||
@@ -7,6 +7,9 @@ from pathlib import Path
|
||||
os.environ["AMD"] = "1"
|
||||
|
||||
from extra.assembly.amd.emu import run_asm as python_run_asm, set_valid_mem_ranges, decode_program
|
||||
from extra.assembly.amd.autogen.rdna3.gen_pcode import _VOP2Op_V_ADD_F32, _VOP2Op_V_MUL_F32, _VOP2Op_V_FMAC_F32, _VOP2Op_V_LSHLREV_B32, _VOP2Op_V_AND_B32
|
||||
from extra.assembly.amd.pcode import Reg
|
||||
from extra.assembly.amd.dsl import _i32
|
||||
|
||||
REMU_PATH = Path(__file__).parents[3] / "remu/target/release/libremu.so"
|
||||
if not REMU_PATH.exists():
|
||||
@@ -121,12 +124,152 @@ def get_tinygrad_kernel(op_name: str) -> tuple[bytes, tuple, tuple, list[int], d
|
||||
|
||||
TINYGRAD_TESTS = ["add", "mul", "reduce_sum", "softmax", "exp", "gelu", "matmul_small"]
|
||||
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
# PCODE MICROBENCHMARKS - test individual pcode function performance
|
||||
# ═══════════════════════════════════════════════════════════════════════════════
|
||||
|
||||
def microbench_pcode(iterations: int = 100000):
  """Microbenchmark individual pcode functions to identify Reg/TypedView overhead.

  Prints the average time per call (in µs) for a few generated VOP2 pcode
  functions, then a breakdown of the pieces they are built from: Reg
  construction, typed property access, TypedView arithmetic and the
  float <-> bits conversions. Each timed statement is written inline in its own
  loop so that no extra call layer is added to what is being measured.

  Args:
    iterations: number of timed repetitions per measurement; must be positive.

  Raises:
    ValueError: if ``iterations`` is not positive (every average divides by it).
  """
  if iterations <= 0: raise ValueError(f"iterations must be positive, got {iterations}")
  # Conversion helpers that are timed on their own further down; imported once, up front.
  from extra.assembly.amd.dsl import _f32, _i32 as dsl_i32

  print("\n" + "=" * 90)
  print("PCODE MICROBENCHMARKS")
  print("=" * 90)

  # Test values (as raw ints, like the emulator passes them)
  f32_1 = _i32(1.5)
  f32_2 = _i32(2.5)
  f32_3 = _i32(0.5)
  int_5 = 5
  int_mask = 0xff00ff00

  tests = [
    ("V_ADD_F32", lambda: _VOP2Op_V_ADD_F32(f32_1, f32_2, 0, 0, 0, 0, 0, 0xffffffff, 0, None)),
    ("V_MUL_F32", lambda: _VOP2Op_V_MUL_F32(f32_1, f32_2, 0, 0, 0, 0, 0, 0xffffffff, 0, None)),
    ("V_FMAC_F32", lambda: _VOP2Op_V_FMAC_F32(f32_1, f32_2, f32_3, f32_3, 0, 0, 0, 0xffffffff, 0, None)),
    ("V_LSHLREV_B32", lambda: _VOP2Op_V_LSHLREV_B32(int_5, int_mask, 0, 0, 0, 0, 0, 0xffffffff, 0, None)),
    ("V_AND_B32", lambda: _VOP2Op_V_AND_B32(int_mask, 0x12345678, 0, 0, 0, 0, 0, 0xffffffff, 0, None)),
  ]

  # Baseline: cost of calling a trivial function that only builds a small result dict
  def baseline_fn(): return {'D0': 42}
  start = time.perf_counter()
  for _ in range(iterations): baseline_fn()
  baseline_us = (time.perf_counter() - start) * 1e6 / iterations
  print(f"\n{'Baseline (empty fn)':<25} {baseline_us:8.3f} µs/call")
  print("-" * 50)

  for name, fn in tests:
    # Warmup
    for _ in range(1000): fn()
    # Timed
    start = time.perf_counter()
    for _ in range(iterations): fn()
    us_per_call = (time.perf_counter() - start) * 1e6 / iterations
    overhead = us_per_call - baseline_us
    print(f"{name:<25} {us_per_call:8.3f} µs/call (overhead: {overhead:6.3f} µs)")

  # Measure Reg creation overhead separately
  print("\n" + "-" * 50)
  print("Component breakdown:")

  # Just Reg creation
  start = time.perf_counter()
  for _ in range(iterations): Reg(f32_1); Reg(f32_2); Reg(0)
  reg_time = (time.perf_counter() - start) * 1e6 / iterations
  print(f" {'3x Reg() creation':<23} {reg_time:8.3f} µs")

  # Reg + property access (no arithmetic)
  start = time.perf_counter()
  for _ in range(iterations):
    r = Reg(f32_1)
    _ = r.f32
  prop_time = (time.perf_counter() - start) * 1e6 / iterations
  print(f" {'Reg() + .f32 access':<23} {prop_time:8.3f} µs")

  # Reg + TypedView arithmetic
  start = time.perf_counter()
  for _ in range(iterations):
    r1, r2 = Reg(f32_1), Reg(f32_2)
    _ = r1.f32 + r2.f32
  arith_time = (time.perf_counter() - start) * 1e6 / iterations
  print(f" {'2x Reg + .f32 + add':<23} {arith_time:8.3f} µs")

  # Full pcode pattern: Reg creation + property + arithmetic + property setter + _val access
  start = time.perf_counter()
  for _ in range(iterations):
    S0, S1, D0 = Reg(f32_1), Reg(f32_2), Reg(0)
    D0.f32 = S0.f32 + S1.f32
    _ = D0._val
  full_time = (time.perf_counter() - start) * 1e6 / iterations
  print(f" {'Full pcode pattern':<23} {full_time:8.3f} µs")

  # Dict creation overhead
  start = time.perf_counter()
  for _ in range(iterations): _ = {'D0': 42}
  dict_time = (time.perf_counter() - start) * 1e6 / iterations
  print(f" {'Dict creation':<23} {dict_time:8.3f} µs")

  # TypedView.__float__ overhead
  r = Reg(f32_1)
  tv = r.f32  # get TypedView once
  start = time.perf_counter()
  for _ in range(iterations): float(tv)
  float_time = (time.perf_counter() - start) * 1e6 / iterations
  print(f" {'TypedView.__float__':<23} {float_time:8.3f} µs")

  # Just _f32 conversion
  start = time.perf_counter()
  for _ in range(iterations): _f32(f32_1)
  f32_time = (time.perf_counter() - start) * 1e6 / iterations
  print(f" {'_f32() conversion':<23} {f32_time:8.3f} µs")

  # TypedView._val property
  start = time.perf_counter()
  for _ in range(iterations): tv._val
  val_time = (time.perf_counter() - start) * 1e6 / iterations
  print(f" {'TypedView._val':<23} {val_time:8.3f} µs")

  # TypedView.__add__ (this calls __float__ twice + Python float add)
  tv2 = Reg(f32_2).f32
  start = time.perf_counter()
  for _ in range(iterations): tv + tv2
  add_time = (time.perf_counter() - start) * 1e6 / iterations
  print(f" {'TypedView + TypedView':<23} {add_time:8.3f} µs")

  # Python float add baseline
  pf1, pf2 = 1.5, 2.5
  start = time.perf_counter()
  for _ in range(iterations): pf1 + pf2
  py_add_time = (time.perf_counter() - start) * 1e6 / iterations
  print(f" {'Python float + float':<23} {py_add_time:8.3f} µs")

  # Setter: D0.f32 = result
  d0 = Reg(0)
  result = 4.0
  start = time.perf_counter()
  for _ in range(iterations): d0.f32 = result
  setter_time = (time.perf_counter() - start) * 1e6 / iterations
  print(f" {'Reg.f32 = float':<23} {setter_time:8.3f} µs")

  # _i32 conversion alone
  start = time.perf_counter()
  for _ in range(iterations): dsl_i32(4.0)
  i32_time = (time.perf_counter() - start) * 1e6 / iterations
  print(f" {'_i32() conversion':<23} {i32_time:8.3f} µs")
|
||||
|
||||
def main():
|
||||
import argparse
|
||||
parser = argparse.ArgumentParser(description="Benchmark RDNA3 emulators")
|
||||
parser.add_argument("--iterations", type=int, default=3, help="Number of iterations per benchmark")
|
||||
parser.add_argument("--ubench", action="store_true", help="Run pcode microbenchmarks only")
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.ubench:
|
||||
microbench_pcode()
|
||||
return
|
||||
|
||||
rust_remu = get_rust_remu()
|
||||
if rust_remu is None:
|
||||
print("Rust libremu not found. Build with: cargo build --release --manifest-path extra/remu/Cargo.toml")
|
||||
|
||||
Reference in New Issue
Block a user