test flops (and allow wide ALU in UOps) [run_process_replay] (#5749)

* flops test in external_test_speed_theoretical.py

* test speed theo

* min SZMAX

* allow wide ALU for things that support it

* needed for mypy
This commit is contained in:
George Hotz
2024-07-26 21:07:28 -07:00
committed by GitHub
parent 2fde2d2914
commit f8972ace38
2 changed files with 32 additions and 15 deletions

View File

@@ -1,31 +1,48 @@
import time, unittest
from tinygrad import Tensor, TinyJit, Device
from tinygrad.helpers import getenv
from tinygrad import Tensor, TinyJit, Device, dtypes
from tinygrad.helpers import getenv, GlobalCounters
def _test(tcount, fxn, szmax):
print(f"**** testing {fxn.__name__}")
SZMAX = getenv("SZMAX", 10)
SZMIN = min(SZMAX, getenv("SZMIN", 10))
def _test(tcount, fxn, dtype=dtypes.float):
print(f"**** testing {fxn.__name__} {dtype}")
allgbs = []
for sz in range(szmax):
for sz in range(SZMIN, SZMAX+1):
jfxn = TinyJit(fxn)
ts = [Tensor.zeros((2**sz)*1024*1024).contiguous().realize() for _ in range(tcount)]
ts = [Tensor.zeros((2**sz)*1024*1024, dtype=dtype).contiguous().realize() for _ in range(tcount)]
tms = []
for _ in range(10):
ts = [(x+1).realize() for x in ts]
Device.default.synchronize()
GlobalCounters.global_ops = 0
GlobalCounters.global_mem = 0
st = time.perf_counter()
out_nbytes = jfxn(*ts).nbytes()
jfxn(*ts).nbytes()
Device.default.synchronize()
tms.append(time.perf_counter() - st)
gbs = (out_nbytes+sum(x.nbytes() for x in ts))*1e-9/min(tms)
print(f"{ts[0].nbytes()/(1024*1024):10.0f} MB, {min(tms)*1e3:6.2f} ms GB/s {gbs:<10.2f} {str(ts[0].shape):20s}")
ops, mem = GlobalCounters.global_ops, GlobalCounters.global_mem
gflops = ops*1e-9/min(tms)
gbs = mem*1e-9/min(tms)
print(f"{ts[0].nbytes()/(1024*1024):10.0f} MB, {min(tms)*1e3:6.2f} ms {gbs:10.2f} GB/s {gflops:10.2f} GFLOPS {str(ts[0].shape):20s}")
allgbs.append(gbs)
return max(allgbs)
MEMBW = getenv("MEMBW", 10)
class TestTheoreticalSpeed(unittest.TestCase):
def test_add(self): self.assertGreater(_test(2, Tensor.add, 11), MEMBW)
def test_exp(self): self.assertGreater(_test(1, Tensor.exp, 11), MEMBW)
def test_sum(self): self.assertGreater(_test(1, Tensor.sum, 11), MEMBW)
class TestRamBandwidth(unittest.TestCase):
def test_add(self): self.assertGreater(_test(2, Tensor.add), MEMBW)
def test_exp(self): self.assertGreater(_test(1, Tensor.exp), MEMBW)
def test_sum(self): self.assertGreater(_test(1, Tensor.sum), MEMBW)
# ratio between MEM and FLOPS < 1000
# NOTE: On AMD, (x*x)+1 gets ~30 TFLOPS, (x*x)+3 gets ~60 TFLOPS
def flopsmax(x):
for _ in range(500): x = (x*x)+3
return x
class TestFlops(unittest.TestCase):
def test_flops_int8(self): _test(1, flopsmax, dtypes.int8)
def test_flops_fp16(self): _test(1, flopsmax, dtypes.half)
def test_flops_fp32(self): _test(1, flopsmax)
if __name__ == '__main__':
unittest.main()

View File

@@ -205,7 +205,6 @@ def type_verify(uops):
assert dtype is None, f"{uop} dtype must be None, got {dtype}"
if len(src) == 4: assert src[3].dtype == dtypes.bool, f"gate dtype mismatch {src[3].dtype} != {dtypes.bool}"
if uop is UOps.ALU:
assert dtype.count == 1, f"wide ALU is not supported on {dtype}"
if arg in UnaryOps: assert dtype == src[0].dtype, f"{arg} dtype mismatch {dtype=} != {src[0].dtype=}"
elif arg in {BinaryOps.CMPLT, BinaryOps.CMPNE}:
assert dtype == dtypes.bool, f"{arg} output dtype mismatch {dtype=} != {dtypes.bool}"
@@ -256,7 +255,8 @@ def flops_mem(uops:List[UOp], ignore_indexing=False) -> Tuple[sint, sint]:
assert u.src[2].dtype is not None
mem += u.src[2].dtype.itemsize * mults
elif u.op is UOps.ALU and u not in dont_count:
flops += mults * (2 if u.arg == TernaryOps.MULACC else 1)
assert u.dtype is not None
flops += (mults * (2 if u.arg == TernaryOps.MULACC else 1)) * u.dtype.count
elif u.op is UOps.WMMA and u not in dont_count:
assert u.arg[1] is not None
flops += 2 * prod(u.arg[1]) // 32 * mults