Mirror of https://github.com/tinygrad/tinygrad.git
jax parallel matmul example
extra/gemm/jax_pmatmul.py (new executable file, 27 lines added)
@@ -0,0 +1,27 @@
#!/usr/bin/env python3
import time
import jax
import jax.numpy as jnp

print(jax.devices())
DEVICES = len(jax.devices())
BS = 32
N = 4096
dtype = jnp.float16
A = jnp.zeros((DEVICES, BS, N, N), dtype)
B = jnp.zeros((1, 1, N, N), dtype)
A = jax.device_put_sharded([A[i] for i in range(DEVICES)], jax.devices())
B = jax.device_put_sharded([B for i in range(DEVICES)], jax.devices())

OPS = DEVICES*BS*N*N*N*2
def matmul(A,B): return jnp.matmul(A,B,preferred_element_type=jnp.float32)
pmatmul = jax.pmap(matmul)

MAX_TFLOPS = 123*DEVICES  # Peak FP16 Tensor TFLOPS with FP32 Acc (7900XTX)
for i in range(10):
  st = time.perf_counter()
  C = pmatmul(A,B).block_until_ready()
  et = time.perf_counter()-st
  tflops = (OPS*1e-12)/et
  print(f"time {et*1e3:.2f} ms, TFLOPS {tflops:6.2f}, MFU {(tflops/MAX_TFLOPS)*100:4.2f}% out shape {C.shape} dtype {C.dtype}")
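Note on the numbers the script prints: each device multiplies BS = 32 batched N x N = 4096 x 4096 matrices, which costs 2*BS*N^3 = 2*32*4096^3 ~= 4.4e12 FLOPs per pmatmul call, and OPS scales that by DEVICES. Since MAX_TFLOPS also scales by DEVICES (123 FP16 TFLOPS per 7900XTX, per the comment in the script), 100% MFU corresponds to roughly 4.4e12 / 123e12 ~= 35.8 ms per call regardless of device count; the MFU column is simply achieved TFLOPS divided by that peak.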
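For comparison, jax.pmap is the older SPMD API; the same benchmark can also be expressed with jax.jit plus NamedSharding, where parallelism follows the input sharding instead of an explicit device axis. The sketch below is not part of this commit; the mesh axis name "dev" and the flattened (DEVICES*BS, N, N) batch layout are illustrative assumptions.

# Sketch only: same benchmark via jax.jit + NamedSharding instead of jax.pmap.
import time
import numpy as np
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

DEVICES = len(jax.devices())
BS, N = 32, 4096
mesh = Mesh(np.array(jax.devices()), axis_names=("dev",))

# A is split across devices along its leading batch axis; B is replicated.
A = jax.device_put(jnp.zeros((DEVICES*BS, N, N), jnp.float16),
                   NamedSharding(mesh, P("dev", None, None)))
B = jax.device_put(jnp.zeros((N, N), jnp.float16),
                   NamedSharding(mesh, P(None, None)))

matmul = jax.jit(lambda a, b: jnp.matmul(a, b, preferred_element_type=jnp.float32))

OPS = DEVICES*BS*N*N*N*2
for _ in range(10):
  st = time.perf_counter()
  C = matmul(A, B).block_until_ready()
  et = time.perf_counter() - st
  print(f"time {et*1e3:.2f} ms, TFLOPS {(OPS*1e-12)/et:6.2f}")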