From bc55c8a30e171033a3cfc89e7c7e037cab14cf67 Mon Sep 17 00:00:00 2001 From: George Hotz <72895+geohot@users.noreply.github.com> Date: Wed, 7 Aug 2024 22:32:11 -0700 Subject: [PATCH] pmatmul example + GB/s bugfix [run_process_replay] (#5974) * pmatmul example + bugfix * improve pmatmul * Update real_pmatmul.py --- extra/gemm/real_pmatmul.py | 20 ++++++++++++++++++++ tinygrad/engine/realize.py | 4 ++-- 2 files changed, 22 insertions(+), 2 deletions(-) create mode 100644 extra/gemm/real_pmatmul.py diff --git a/extra/gemm/real_pmatmul.py b/extra/gemm/real_pmatmul.py new file mode 100644 index 0000000000..b4b0202894 --- /dev/null +++ b/extra/gemm/real_pmatmul.py @@ -0,0 +1,20 @@ +import time +from tinygrad import Tensor, Device, TinyJit +from tinygrad.helpers import getenv + +if __name__ == "__main__": + DEVS = [f"NV:{i}" for i in range(getenv("GPUS", 2))] + N = getenv("N", 8192) + A = Tensor.rand(N, N).shard(DEVS, 0).realize() + B = Tensor.rand(N, N).shard(DEVS, 1).realize() + print("***** MUL *****") + jmatmul = TinyJit(Tensor.dot) + for i in range(10): + Device["NV:0"].synchronize() + Device["NV:1"].synchronize() + st = time.perf_counter() + jmatmul(A, B) + Device["NV:0"].synchronize() + Device["NV:1"].synchronize() + et = time.perf_counter() + print(f"{(N*N*N*2*1e-12)/(et-st):.2f} TFLOPS") diff --git a/tinygrad/engine/realize.py b/tinygrad/engine/realize.py index 5987183a32..a985d5f369 100644 --- a/tinygrad/engine/realize.py +++ b/tinygrad/engine/realize.py @@ -66,9 +66,9 @@ def get_kernel(renderer:Renderer, ast:LazyOp) -> Kernel: # **************** Runners **************** class Runner: - def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0, lds_estimate:sint=0): + def __init__(self, display_name:str, dname:str, op_estimate:sint=0, mem_estimate:sint=0, lds_estimate:Optional[sint]=None): self.first_run, self.display_name, self.dname, self.op_estimate, self.mem_estimate, self.lds_estimate = \ - True, display_name, dname, op_estimate, mem_estimate, lds_estimate + True, display_name, dname, op_estimate, mem_estimate, mem_estimate if lds_estimate is None else lds_estimate @property def device(self): return Device[self.dname] def exec(self, rawbufs:List[Buffer], var_vals:Optional[Dict[Variable, int]]=None) -> Optional[float]: