metal matmul from tcores branch
@@ -1,9 +1,10 @@
 import os
 os.environ["METAL"] = "1"
 import numpy as np
-from tinygrad.helpers import dtypes
+from tinygrad.helpers import dtypes, getenv
 from tinygrad.runtime.ops_metal import RawMetalBuffer, MetalProgram
 
-N = 2048
+N = getenv("N", 2048)
 
 a = RawMetalBuffer(N*N, dtypes.float32)
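Review note: `getenv` makes the matrix size overridable from the environment while keeping 2048 as the default. For readers unfamiliar with the helper, a minimal sketch of its behavior (not the actual tinygrad source, which also caches lookups):

```python
import os

# sketch of tinygrad.helpers.getenv: read an env var and coerce it
# to the type of the default, falling back to the default if unset
def getenv(key, default=0):
  return type(default)(os.getenv(key, default))

N = getenv("N", 2048)  # e.g. set N=512 in the environment to benchmark a 512x512 matmul
```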
@@ -13,31 +14,16 @@ b = RawMetalBuffer.fromCPU(nb)
 c = RawMetalBuffer.fromCPU(nc)
 
 FLOPS = N*N*N*2
-BW = N*N*3
+BW = N*N*3*4
 
 prog = MetalProgram("test", f"""
 #include <metal_stdlib>
 #include <metal_simdgroup_matrix>  // Available from Metal 2.3, released with macOS 11.0+
 using namespace metal;
 kernel void test(device float *a, device const float *data1, device const float *data2, uint3 gid [[thread_position_in_grid]], uint3 xid [[threadgroup_position_in_grid]], uint3 lid [[thread_position_in_threadgroup]], uint sidx [[simdgroup_index_in_threadgroup]]) {{
-  // 1-2 simd groups
-  //uint idx = gid.x/32;
-  //uint pos_x = (idx%{N//32}) * 32;
-  //uint pos_y = (idx/{N//32}) * 32;
-
-  // 4 simd groups
-  uint idx = gid.x/128;
-  uint pos_x = (idx%{N//64}) * 64;
-  uint pos_y = (idx/{N//64}) * 64;
-  pos_x += (sidx%2) * 32;
-  pos_y += (sidx/2) * 32;
-
-  // 16 simd groups (slow)
-  /*uint idx = gid.x/512;
-  uint pos_x = (idx%{N//128}) * 128;
-  uint pos_y = (idx/{N//128}) * 128;
-  pos_x += (sidx%4) * 32;
-  pos_y += (sidx/4) * 32;*/
+  a += gid.y * 32 * {N} + gid.z * 32;
+  data1 += gid.y * 32 * {N};
+  data2 += gid.z * 32;
 
   simdgroup_float8x8 acc[4][4];
   for (uint i = 0; i < 4; i++) {{
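Review note on the `BW` fix: the old figure counted elements, the new one counts bytes. The kernel touches three N x N float32 buffers (reads `data1` and `data2`, writes `a`) at 4 bytes per element; re-reads of input tiles across output tiles are deliberately not counted. A quick worked check, with `FLOPS` included for comparison:

```python
# worked check of the two roofline numbers used by the benchmark
N = 2048
FLOPS = N * N * N * 2  # each of the N*N outputs needs N multiply-adds = 2 ops
BW = N * N * 3 * 4     # three N x N float32 buffers, 4 bytes per element
print(f"{FLOPS/1e9:.1f} GFLOP, {BW/1e9:.3f} GB per matmul")  # 17.2 GFLOP, 0.050 GB
```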
@@ -45,21 +31,19 @@ kernel void test(device float *a, device const float *data1, device const float
       acc[i][j] = simdgroup_float8x8(0);
     }}
   }}
 
   simdgroup_float8x8 A[4];
   simdgroup_float8x8 B[4];
-  data1 += pos_x * {N};
-  data2 += pos_y;
 
   for (uint k = 0; k < {N}; k+=8) {{
     threadgroup_barrier(mem_flags::mem_threadgroup);
-    simdgroup_load(A[0], data1, {N}, ulong2(k, 0));
-    simdgroup_load(A[1], data1, {N}, ulong2(k, 8));
-    simdgroup_load(A[2], data1, {N}, ulong2(k, 16));
-    simdgroup_load(A[3], data1, {N}, ulong2(k, 24));
-    simdgroup_load(B[0], data2, {N}, ulong2(0, k));
-    simdgroup_load(B[1], data2, {N}, ulong2(8, k));
-    simdgroup_load(B[2], data2, {N}, ulong2(16, k));
-    simdgroup_load(B[3], data2, {N}, ulong2(24, k));
+    simdgroup_load(A[0], data1+k+{0*N}, {N}, ulong2(0, 0));
+    simdgroup_load(A[1], data1+k+{8*N}, {N}, ulong2(0, 0));
+    simdgroup_load(A[2], data1+k+{16*N}, {N}, ulong2(0, 0));
+    simdgroup_load(A[3], data1+k+{24*N}, {N}, ulong2(0, 0));
+    simdgroup_load(B[0], data2+0+k*{N}, {N}, ulong2(0, 0));
+    simdgroup_load(B[1], data2+8+k*{N}, {N}, ulong2(0, 0));
+    simdgroup_load(B[2], data2+16+k*{N}, {N}, ulong2(0, 0));
+    simdgroup_load(B[3], data2+24+k*{N}, {N}, ulong2(0, 0));
 
     simdgroup_multiply_accumulate(acc[0][0], A[0], B[0], acc[0][0]);
     simdgroup_multiply_accumulate(acc[0][1], A[1], B[0], acc[0][1]);
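Review note on the rewritten loads: each simdgroup accumulates a 32x32 tile as a 4x4 grid of 8x8 `simdgroup_float8x8` matrices. The old form passed the tile origin through `simdgroup_load`'s `ulong2` offset; the new form folds the same offset into the pointer (computed at f-string time) and passes `ulong2(0, 0)`. Assuming the documented `matrix_origin` semantics, where `ulong2(x, y)` selects the 8x8 tile whose top-left element is `ptr[y*stride + x]`, the two forms address the same elements:

```python
# check that the old ulong2 offsets and the new pointer arithmetic agree
N = 2048

def origin(base_off, stride, x, y):  # simdgroup_load(M, ptr+base_off, stride, ulong2(x, y))
  return base_off + y * stride + x

for k in range(0, N, 8):
  for i in range(4):
    # A[i]: old simdgroup_load(A[i], data1, N, ulong2(k, 8*i))
    #       new simdgroup_load(A[i], data1 + k + 8*i*N, N, ulong2(0, 0))
    assert origin(0, N, k, 8 * i) == origin(k + 8 * i * N, N, 0, 0)
    # B[i]: old simdgroup_load(B[i], data2, N, ulong2(8*i, k))
    #       new simdgroup_load(B[i], data2 + 8*i + k*N, N, ulong2(0, 0))
    assert origin(0, N, 8 * i, k) == origin(8 * i + k * N, N, 0, 0)
```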
@@ -78,19 +62,30 @@ kernel void test(device float *a, device const float *data1, device const float
     simdgroup_multiply_accumulate(acc[3][2], A[2], B[3], acc[3][2]);
     simdgroup_multiply_accumulate(acc[3][3], A[3], B[3], acc[3][3]);
   }}
-  for (uint i = 0; i < 4; i++) {{
-    for (uint j = 0; j < 4; j++) {{
-      simdgroup_store(acc[i][j], a, {N}, ulong2(pos_y+i*8, pos_x+j*8));
-    }}
-  }}
+  simdgroup_store(acc[0][0], a+{0+0*N}, {N}, ulong2(0, 0));
+  simdgroup_store(acc[1][0], a+{8+0*N}, {N}, ulong2(0, 0));
+  simdgroup_store(acc[2][0], a+{16+0*N}, {N}, ulong2(0, 0));
+  simdgroup_store(acc[3][0], a+{24+0*N}, {N}, ulong2(0, 0));
+  simdgroup_store(acc[0][1], a+{0+8*N}, {N}, ulong2(0, 0));
+  simdgroup_store(acc[1][1], a+{8+8*N}, {N}, ulong2(0, 0));
+  simdgroup_store(acc[2][1], a+{16+8*N}, {N}, ulong2(0, 0));
+  simdgroup_store(acc[3][1], a+{24+8*N}, {N}, ulong2(0, 0));
+  simdgroup_store(acc[0][2], a+{0+16*N}, {N}, ulong2(0, 0));
+  simdgroup_store(acc[1][2], a+{8+16*N}, {N}, ulong2(0, 0));
+  simdgroup_store(acc[2][2], a+{16+16*N}, {N}, ulong2(0, 0));
+  simdgroup_store(acc[3][2], a+{24+16*N}, {N}, ulong2(0, 0));
+  simdgroup_store(acc[0][3], a+{0+24*N}, {N}, ulong2(0, 0));
+  simdgroup_store(acc[1][3], a+{8+24*N}, {N}, ulong2(0, 0));
+  simdgroup_store(acc[2][3], a+{16+24*N}, {N}, ulong2(0, 0));
+  simdgroup_store(acc[3][3], a+{24+24*N}, {N}, ulong2(0, 0));
 }}""")
-tm = min([prog([N*N//(2*4*4)], [4*32], a, b, c, wait=True) for _ in range(20)])
+tm = min([prog([32, N//(8*4), N//(8*4)], [32, 1, 4], a, b, c, wait=True) for _ in range(20)])
 na = a.toCPU().reshape(N,N)
 comp = nb@nc
 if N <= 32:
   print(na)
   print(comp)
-print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
+print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul, {BW*1e-9/tm:.2f} GB/s")
 np.testing.assert_allclose(na, comp, atol=1e-3)
 
 import time, torch, torch.mps
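Review note on the new launch: `[32, N//(8*4), N//(8*4)]` with a `[32, 1, 4]` threadgroup replaces the old flat 1D launch. Assuming `global_size` counts threads (the reading under which both the old and new launches account for exactly one simdgroup per output tile), each 32-thread simdgroup spans the x dimension and gets a unique `(gid.y, gid.z)` pair, i.e. one 32x32 tile; the four simdgroups in a `[32, 1, 4]` threadgroup differ only in `gid.z`, so they cover four adjacent column tiles. A sanity check on the bookkeeping:

```python
# tile bookkeeping for the new launch (assuming global_size is in threads)
N = 2048
global_size = [32, N // (8 * 4), N // (8 * 4)]  # [32, 64, 64] for N=2048
total_threads = global_size[0] * global_size[1] * global_size[2]
n_simdgroups = total_threads // 32              # 32 threads per simdgroup
n_tiles = (N // 32) * (N // 32)                 # 32x32 output tiles in the result
assert n_simdgroups == n_tiles                  # 4096 == 4096: one simdgroup per tile
```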
@@ -103,4 +98,21 @@ def torch_prog(b, c):
   torch.mps.synchronize()
   return time.perf_counter() - st
 tm = min([torch_prog(b, c) for _ in range(20)])
-print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:.2f} GFLOPS matmul in torch")
+print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in torch")
+
+from tinygrad.tensor import Tensor
+from tinygrad.jit import TinyJit
+from tinygrad.runtime.ops_metal import METAL
+b = Tensor(nb)
+c = Tensor(nc)
+# TODO: I suspect the slowness without the JIT comes from the lack of a caching allocator
+@TinyJit
+def tiny_jit(b, c):
+  return (b@c).realize()
+def tiny_prog(b, c):
+  st = time.perf_counter()
+  a = tiny_jit(b, c)
+  METAL.synchronize()
+  return time.perf_counter() - st
+tm = min([tiny_prog(b, c) for _ in range(20)])
+print(f"{N*N:10d} {tm*1e6:9.2f} us, would be {FLOPS*1e-9/tm:9.2f} GFLOPS matmul in tinygrad")