From 3f137b134af51814fbcdd7a055b6d79757f66414 Mon Sep 17 00:00:00 2001
From: George Hotz
Date: Tue, 28 Nov 2023 13:48:11 -0800
Subject: [PATCH] jax parallel matmul example

---
 extra/gemm/jax_pmatmul.py | 27 +++++++++++++++++++++++++++
 1 file changed, 27 insertions(+)
 create mode 100755 extra/gemm/jax_pmatmul.py

diff --git a/extra/gemm/jax_pmatmul.py b/extra/gemm/jax_pmatmul.py
new file mode 100755
index 0000000000..b69a2b9b47
--- /dev/null
+++ b/extra/gemm/jax_pmatmul.py
@@ -0,0 +1,27 @@
+#!/usr/bin/env python3
+import time
+import jax
+import jax.numpy as jnp
+
+print(jax.devices())
+DEVICES = len(jax.devices())
+BS = 32
+N = 4096
+dtype = jnp.float16
+A = jnp.zeros((DEVICES, BS, N, N), dtype)
+B = jnp.zeros((1, 1, N, N), dtype)
+A = jax.device_put_sharded([A[i] for i in range(DEVICES)], jax.devices())  # one (BS, N, N) shard of A per device
+B = jax.device_put_sharded([B for i in range(DEVICES)], jax.devices())  # replicate B on every device
+
+OPS = DEVICES*BS*N*N*N*2  # 2*N^3 FLOPs per NxN matmul, BS matmuls on each device
+def matmul(A,B): return jnp.matmul(A,B,preferred_element_type=jnp.float32)
+pmatmul = jax.pmap(matmul)  # run matmul in parallel across the leading device axis
+
+MAX_TFLOPS = 123*DEVICES # Peak FP16 Tensor TFLOPS with FP32 Acc (7900XTX)
+for i in range(10):
+  st = time.perf_counter()
+  C = pmatmul(A,B).block_until_ready()
+  et = time.perf_counter()-st
+  tflops = (OPS*1e-12)/et
+  print(f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}")
+
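
Note (editorial, not part of the patch): jax.pmap and jax.device_put_sharded are
legacy APIs; current JAX recommends jax.jit with explicit shardings for the same
data-parallel pattern. Below is a minimal sketch under that assumption (JAX 0.4+
with the jax.sharding module); names and shapes mirror the patch, and the MFU
print is omitted for brevity.

#!/usr/bin/env python3
# Sketch: the patch's benchmark expressed with jax.jit + NamedSharding.
import time
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

DEVICES = len(jax.devices())
BS, N = 32, 4096
mesh = Mesh(np.array(jax.devices()), ("d",))  # 1-D device mesh with axis name "d"

# Split A along its leading batch axis, one (BS, N, N) shard per device;
# replicate B everywhere (an empty PartitionSpec partitions no axes).
A = jax.device_put(jnp.zeros((DEVICES*BS, N, N), jnp.float16), NamedSharding(mesh, P("d")))
B = jax.device_put(jnp.zeros((N, N), jnp.float16), NamedSharding(mesh, P()))

@jax.jit
def matmul(a, b): return jnp.matmul(a, b, preferred_element_type=jnp.float32)

OPS = DEVICES*BS*N*N*N*2
for _ in range(10):
  st = time.perf_counter()
  # The compiler parallelizes over the sharded batch axis; with B replicated
  # there is no cross-device communication, as in the pmap version.
  C = matmul(A, B).block_until_ready()
  et = time.perf_counter()-st
  print(f"time {et*1e3:.2f} ms, TFLOPS {(OPS*1e-12)/et:6.2f}, out shape {C.shape} dtype {C.dtype}")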