metal_conv gets over 10.4 TFLOPS...

George Hotz
2023-04-15 03:31:22 -07:00
parent d66e682205
commit 8b777af571
2 changed files with 53 additions and 3 deletions

extra/gemm/metal_conv.py (new file, 44 additions)

@@ -0,0 +1,44 @@
import os
os.environ["METAL"] = "1"
import numpy as np
BS = 64
CIN = 256
COUT = 256
HW = 32
K = 3
# TODO: this is doing some trick, since with CIN=256 COUT=256 it's over 10.4 TFLOPS.
# do winograd convs count as fewer FLOPs? (see the note after this file)
FLOPS = BS*K*K*CIN*HW*HW*COUT*2
nb = np.random.default_rng().standard_normal(size=(BS,CIN,HW,HW), dtype=np.float32)
nc = np.random.default_rng().standard_normal(size=(COUT,CIN,K,K), dtype=np.float32)
import time, torch, torch.mps
b = torch.from_numpy(nb).to('mps')
c = torch.from_numpy(nc).to('mps')
def torch_prog(b, c):
  st = time.perf_counter()
  a = torch.nn.functional.conv2d(b, c, padding=1)
  torch.mps.synchronize()
  return time.perf_counter() - st
tm = min([torch_prog(b, c) for _ in range(20)])
print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in torch")
from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.runtime.ops_metal import METAL
b = Tensor(nb)
c = Tensor(nc)
# TODO: I suspect the slowness without the JIT comes from the lack of a caching allocator (see the sketch after this file)
@TinyJit
def tiny_jit(b, c):
  return b.conv2d(c, padding=1).realize()
def tiny_prog(b, c):
  st = time.perf_counter()
  a = tiny_jit(b, c)
  METAL.synchronize()
  return time.perf_counter() - st
tm = min([tiny_prog(b, c) for _ in range(5)])
print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in tinygrad")

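For the caching-allocator TODO in the same file, a non-JIT variant could be timed against tiny_jit to quantify that overhead; this is a hypothetical sketch using only calls already present above, not something the commit adds:

def tiny_prog_nojit(b, c):
  st = time.perf_counter()
  a = b.conv2d(c, padding=1).realize()  # re-lowers and re-allocates buffers on every call
  METAL.synchronize()
  return time.perf_counter() - st
# tm_nojit = min([tiny_prog_nojit(b, c) for _ in range(5)])
# print(f"{tm_nojit*1e6:9.2f} us, would be {FLOPS*1e-9/tm_nojit:9.2f} GFLOPS conv in tinygrad without the JIT")
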
(second changed file)

@@ -1,5 +1,6 @@
import os
os.environ["METAL"] = "1"
+import time
import numpy as np
from tinygrad.helpers import dtypes, getenv
from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram
@@ -79,7 +80,12 @@ kernel void test(device float *a, device const float *data1, device const float
simdgroup_store(acc[2][3], a+{16+24*N}, {N}, ulong2(0, 0));
simdgroup_store(acc[3][3], a+{24+24*N}, {N}, ulong2(0, 0));
}}""")
-tm = min([prog([32, N//(8*4), N//(8*4)], [32, 1, 4], a, b, c, wait=True) for _ in range(20)])
+def timeit(fxn):
+  st = time.perf_counter()
+  et = fxn()
+  # NOTE: et doesn't contain the launch overhead
+  return time.perf_counter() - st
+tm = min([timeit(lambda: prog([32, N//(8*4), N//(8*4)], [32, 1, 4], a, b, c, wait=True)) for _ in range(20)])
na = a.toCPU().reshape(N,N)
comp = nb@nc
if N <= 32:
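
The timeit change above measures host wall-clock time around the call instead of using the value prog returns; per the NOTE, that return value (et) is the device-reported elapsed time and excludes launch overhead. A minimal sketch that captures both numbers, under that assumption (not part of the commit):

def timeit_both(fxn):
  st = time.perf_counter()
  et = fxn()                       # assumed: device-reported kernel time, no launch overhead
  wall = time.perf_counter() - st  # host wall clock: encode + submit + sync + kernel
  return et, wall
# et, wall = timeit_both(lambda: prog([32, N//(8*4), N//(8*4)], [32, 1, 4], a, b, c, wait=True))
# print(f"kernel {et*1e6:9.2f} us, wall {wall*1e6:9.2f} us, overhead {(wall-et)*1e6:9.2f} us")
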
@@ -88,7 +94,7 @@ if N <= 32:
print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
np.testing.assert_allclose(na, comp, atol=1e-3)
-import time, torch, torch.mps
+import torch, torch.mps
b = torch.from_numpy(nb).to('mps')
c = torch.from_numpy(nc).to('mps')
@@ -104,7 +110,7 @@ from tinygrad.tensor import Tensor
from tinygrad.jit import TinyJit
from tinygrad.runtime.ops_metal import METAL
b = Tensor(nb)
-c = Tensor(nb)
+c = Tensor(nc)
# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
@TinyJit
def tiny_jit(b, c):