Files
tinygrad/extra/thunder/cuda/matmul.py
2025-10-30 19:09:27 -07:00

56 lines
1.7 KiB
Python

import pathlib
from tinygrad import Device, Tensor
from tinygrad.helpers import Context, getenv
from tinygrad.runtime.support.compiler_cuda import pretty_ptx, NVCCCompiler
if __name__ == "__main__":
  # Benchmark driver: compiles a hand-written CUDA matmul kernel with nvcc, launches it on
  # N x N bfloat16 matrices a few times while printing throughput, then compares the output
  # against tinygrad's own `a @ b`.
  WARP_THREADS = 32
  BLOCK_SIZE = 32
  N = 8192
  here = pathlib.Path(__file__).parent

  # The MATMUL2 env var selects the alternate kernel together with its launch configuration.
  if getenv("MATMUL2"):
    source_name = "matmul2.cu"
    smem = 16384 * 2
    SUPER_N = SUPER_M = 2
  else:
    source_name = "matmul.cu"
    smem = 10000
    SUPER_N = SUPER_M = 1
  NUM_WORKERS = SUPER_N * SUPER_M
  # One workgroup per (BLOCK_SIZE*SUPER_N) x (BLOCK_SIZE*SUPER_M) slice of the N x N grid.
  # NOTE(review): N must be divisible by both tile extents, otherwise // silently drops the remainder.
  gsz = (N // (BLOCK_SIZE * SUPER_N), N // (BLOCK_SIZE * SUPER_M), 1)

  code = (here / source_name).read_text()

  device = Device["CUDA"]
  # The kernel sources include headers from the sibling `include` directory and need C++20.
  kitten_args = [f"-I{(here / 'include').as_posix()}", "-std=c++20", "--expt-relaxed-constexpr"]
  lib = NVCCCompiler(device.compiler.arch, kitten_args).compile(code)

  # The entry point name is whatever follows the first `.globl<TAB>` directive in the compiled text.
  listing = lib.decode()
  kernel_name = listing.split(".globl\t")[1].split("\n")[0]
  print("kernel name", kernel_name)
  print(pretty_ptx(listing))

  prg = device.runtime(kernel_name, lib)
  # NOTE(review): presumably the shared-memory size in bytes requested for the launch -- confirm.
  prg.smem = smem

  a = Tensor.randn(N, N, device='CUDA', dtype="bfloat16")
  b = Tensor.randn(N, N, device='CUDA', dtype="bfloat16")
  c = Tensor.empty(N, N, device='CUDA', dtype="bfloat16")
  Tensor.realize(a, b, c)

  # Time the custom kernel. The return value of the launch with wait=True is used as elapsed
  # seconds in the throughput formula below (2*N^3 floating point operations per matmul).
  for _ in range(5):
    et = prg(c.uop.buffer.ensure_allocated()._buf, a.uop.buffer._buf, b.uop.buffer._buf,
             global_size=gsz, local_size=(NUM_WORKERS*WARP_THREADS, 1, 1), wait=True)
    # `.2f` (two decimals); the previous `2f` spec only set a minimum field width of 2.
    print(f"{N*N*N*2/(et*1e9):.2f} GFLOPS")
  # print(c.tolist())

  # Reference result from tinygrad itself; DEBUG=2 makes tinygrad print its own per-kernel info.
  for _ in range(5):
    with Context(DEBUG=2):
      ref = (a@b).realize()

  # Compare in float: prints the mean and the max of the signed difference (ref - c).
  ref, c = ref.float(), c.float()
  print((ref-c).mean().item(), (ref-c).max().item())