From f8972ace381de0cc924754b3631cd62d2dadcc07 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Fri, 26 Jul 2024 21:07:28 -0700 Subject: [PATCH] test flops (and allow wide ALU in UOps) [run_process_replay] (#5749) * flops test in external_test_speed_theoretical.py * test speed theo * min SZMAX * allow wide ALU for things that support it * needed for mypy --- .../external_test_speed_theoretical.py | 43 +++++++++++++------ tinygrad/codegen/uops.py | 4 +- 2 files changed, 32 insertions(+), 15 deletions(-) diff --git a/test/external/external_test_speed_theoretical.py b/test/external/external_test_speed_theoretical.py index fd2f825265..ea0f260425 100644 --- a/test/external/external_test_speed_theoretical.py +++ b/test/external/external_test_speed_theoretical.py @@ -1,31 +1,48 @@ import time, unittest -from tinygrad import Tensor, TinyJit, Device -from tinygrad.helpers import getenv +from tinygrad import Tensor, TinyJit, Device, dtypes +from tinygrad.helpers import getenv, GlobalCounters -def _test(tcount, fxn, szmax): - print(f"**** testing {fxn.__name__}") +SZMAX = getenv("SZMAX", 10) +SZMIN = min(SZMAX, getenv("SZMIN", 10)) +def _test(tcount, fxn, dtype=dtypes.float): + print(f"**** testing {fxn.__name__} {dtype}") allgbs = [] - for sz in range(szmax): + for sz in range(SZMIN, SZMAX+1): jfxn = TinyJit(fxn) - ts = [Tensor.zeros((2**sz)*1024*1024).contiguous().realize() for _ in range(tcount)] + ts = [Tensor.zeros((2**sz)*1024*1024, dtype=dtype).contiguous().realize() for _ in range(tcount)] tms = [] for _ in range(10): ts = [(x+1).realize() for x in ts] Device.default.synchronize() + GlobalCounters.global_ops = 0 + GlobalCounters.global_mem = 0 st = time.perf_counter() - out_nbytes = jfxn(*ts).nbytes() + jfxn(*ts).nbytes() Device.default.synchronize() tms.append(time.perf_counter() - st) - gbs = (out_nbytes+sum(x.nbytes() for x in ts))*1e-9/min(tms) - print(f"{ts[0].nbytes()/(1024*1024):10.0f} MB, {min(tms)*1e3:6.2f} ms GB/s {gbs:<10.2f} {str(ts[0].shape):20s}") + ops, mem = GlobalCounters.global_ops, GlobalCounters.global_mem + gflops = ops*1e-9/min(tms) + gbs = mem*1e-9/min(tms) + print(f"{ts[0].nbytes()/(1024*1024):10.0f} MB, {min(tms)*1e3:6.2f} ms {gbs:10.2f} GB/s {gflops:10.2f} GFLOPS {str(ts[0].shape):20s}") allgbs.append(gbs) return max(allgbs) MEMBW = getenv("MEMBW", 10) -class TestTheoreticalSpeed(unittest.TestCase): - def test_add(self): self.assertGreater(_test(2, Tensor.add, 11), MEMBW) - def test_exp(self): self.assertGreater(_test(1, Tensor.exp, 11), MEMBW) - def test_sum(self): self.assertGreater(_test(1, Tensor.sum, 11), MEMBW) +class TestRamBandwidth(unittest.TestCase): + def test_add(self): self.assertGreater(_test(2, Tensor.add), MEMBW) + def test_exp(self): self.assertGreater(_test(1, Tensor.exp), MEMBW) + def test_sum(self): self.assertGreater(_test(1, Tensor.sum), MEMBW) + +# ratio between MEM and FLOPS < 1000 +# NOTE: On AMD, (x*x)+1 gets ~30 TFLOPS, (x*x)+3 gets ~60 TFLOPS +def flopsmax(x): + for _ in range(500): x = (x*x)+3 + return x + +class TestFlops(unittest.TestCase): + def test_flops_int8(self): _test(1, flopsmax, dtypes.int8) + def test_flops_fp16(self): _test(1, flopsmax, dtypes.half) + def test_flops_fp32(self): _test(1, flopsmax) if __name__ == '__main__': unittest.main() diff --git a/tinygrad/codegen/uops.py b/tinygrad/codegen/uops.py index 7a8eaed569..276b647d6a 100644 --- a/tinygrad/codegen/uops.py +++ b/tinygrad/codegen/uops.py @@ -205,7 +205,6 @@ def type_verify(uops): assert dtype is None, f"{uop} dtype must be None, got {dtype}" if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}" if uop is UOps.ALU: - assert dtype.count == 1, f"wide ALU is not supported on {dtype}" if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}" elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}: assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}" @@ -256,7 +255,8 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]: assert u.src[2].dtype is not None mem += u.src[2].dtype.itemsize * mults elif u.op is UOps.ALU and u not in dont_count: - flops += mults * (2 if u.arg == TernaryOps.MULACC else 1) + assert u.dtype is not None + flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count elif u.op is UOps.WMMA and u not in dont_count: assert u.arg[1] is not None flops += 2 * prod(u.arg[1]) // 32 * mults