metal_conv gets over 10.4 TFLOPS...
extra/gemm/metal_conv.py (new file, 44 lines added)
@@ -0,0 +1,44 @@
+import os
+os.environ["METAL"] = "1"
+import numpy as np
+
+BS = 64
+CIN = 256
+COUT = 256
+HW = 32
+K = 3
+# TODO: this is doing some trick, since with CIN=256 COUT=256 it's over 10.4 TFLOPS.
+# are winograd convs less flops?
+FLOPS = BS*K*K*CIN*HW*HW*COUT*2
+
+nb = np.random.default_rng().standard_normal(size=(BS,CIN,HW,HW), dtype=np.float32)
+nc = np.random.default_rng().standard_normal(size=(COUT,CIN,K,K), dtype=np.float32)
+
+import time, torch, torch.mps
+b = torch.from_numpy(nb).to('mps')
+c = torch.from_numpy(nc).to('mps')
+
+def torch_prog(b, c):
+  st = time.perf_counter()
+  a = torch.nn.functional.conv2d(b, c, padding=1)
+  torch.mps.synchronize()
+  return time.perf_counter() - st
+tm = min([torch_prog(b, c) for _ in range(20)])
+print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in torch")
+
+from tinygrad.tensor import Tensor
+from tinygrad.jit import TinyJit
+from tinygrad.runtime.ops_metal import METAL
+b = Tensor(nb)
+c = Tensor(nc)
+# TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
+@TinyJit
+def tiny_jit(b, c):
+  return b.conv2d(c, padding=1).realize()
+def tiny_prog(b, c):
+  st = time.perf_counter()
+  a = tiny_jit(b, c)
+  METAL.synchronize()
+  return time.perf_counter() - st
+tm = min([tiny_prog(b, c) for _ in range(5)])
+print(f"{tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS conv in tinygrad")
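A note on the TODO in the new file: FLOPS counts direct-convolution multiply-adds, about 77.3 GFLOP per call at these sizes. One plausible answer to "are winograd convs less flops?" is yes: if the MPS backend lowers this 3x3 conv to Winograd F(2x2,3x3) (an assumption, not confirmed by the commit), it performs roughly 2.25x fewer multiplies than the direct count, so a rate derived from FLOPS can exceed the GPU's fp32 peak. A rough sketch of the arithmetic:

# Why an "effective" rate computed from direct-conv FLOPs can beat peak.
BS, CIN, COUT, HW, K = 64, 256, 256, 32, 3
flops = BS*K*K*CIN*HW*HW*COUT*2           # direct 3x3 conv multiply-adds
print(f"{flops*1e-9:.1f} GFLOP per conv") # ~77.3

# Winograd F(2x2,3x3): a 2x2 output tile costs 4*4 = 16 multiplies after the
# transforms, vs 2*2 * 3*3 = 36 for the direct method -> 2.25x fewer.
savings = 36 / 16
print(f"10.4 effective TFLOPS ~ {10.4/savings:.2f} TFLOPS of actual multiplies")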
extra/gemm/metal_matmul.py
@@ -1,5 +1,6 @@
 import os
 os.environ["METAL"] = "1"
+import time
 import numpy as np
 from tinygrad.helpers import dtypes, getenv
 from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram
@@ -79,7 +80,12 @@ kernel void test(device float *a, device const float *data1, device const float
     simdgroup_store(acc[2][3], a+{16+24*N}, {N}, ulong2(0, 0));
     simdgroup_store(acc[3][3], a+{24+24*N}, {N}, ulong2(0, 0));
 }}""")
-tm = min([prog([32, N//(8*4), N//(8*4)], [32, 1, 4], a, b, c, wait=True) for _ in range(20)])
+def timeit(fxn):
+  st = time.perf_counter()
+  et = fxn()
+  # NOTE: et doesn't contain the launch overhead
+  return time.perf_counter() - st
+tm = min([timeit(lambda: prog([32, N//(8*4), N//(8*4)], [32, 1, 4], a, b, c, wait=True)) for _ in range(20)])
 na = a.toCPU().reshape(N,N)
 comp = nb@nc
 if N <= 32:
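For context on the timeit change: prog(..., wait=True) returns the device-reported execution time (et above), which excludes kernel launch and synchronization overhead, so the diff times the whole call with wall-clock perf_counter instead. The same pattern as a self-contained sketch, with a CPU function standing in for the kernel launch:

import time

def timeit(fxn):
  st = time.perf_counter()
  et = fxn()  # a device-reported time would land here; it is intentionally unused
  return time.perf_counter() - st

# min() over repeats keeps the best case, filtering out scheduler noise.
tm = min(timeit(lambda: sum(x*x for x in range(10**5))) for _ in range(20))
print(f"{tm*1e6:9.2f} us")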
@@ -88,7 +94,7 @@ if N <= 32:
 print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
 np.testing.assert_allclose(na, comp, atol=1e-3)
 
-import time, torch, torch.mps
+import torch, torch.mps
 b = torch.from_numpy(nb).to('mps')
 c = torch.from_numpy(nc).to('mps')
 
@@ -104,7 +110,7 @@ from tinygrad.tensor import Tensor
 from tinygrad.jit import TinyJit
 from tinygrad.runtime.ops_metal import METAL
 b = Tensor(nb)
-c = Tensor(nb)
+c = Tensor(nc)
 # TODO: slowness without the JIT I suspect comes from a lack of a caching allocator
 @TinyJit
 def tiny_jit(b, c):