mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 15:08:02 -05:00
@@ -1,4 +1,5 @@
|
||||
# include directory copied from https://github.com/HazyResearch/ThunderMittens
|
||||
# https://hazyresearch.stanford.edu/blog/2024-11-28-tk-mlx
|
||||
|
||||
gemm = """
|
||||
#include <metal_stdlib>
|
||||
@@ -41,10 +42,9 @@ kernel void matmul_naive(GEMM_PARAMS_DEF(T)) {
|
||||
instantiate_matmul_custom(float32, float);
|
||||
"""
|
||||
|
||||
from tinygrad import Device, Tensor
|
||||
from tinygrad import Device, Tensor, Context
|
||||
|
||||
if __name__ == "__main__":
|
||||
# TODO: why isn't this type inferred?
|
||||
device = Device["METAL"]
|
||||
lib = device.compiler.compile(gemm)
|
||||
prg = device.runtime("matmul_custom_float32", lib)
|
||||
@@ -65,7 +65,10 @@ if __name__ == "__main__":
|
||||
global_size=gsz, local_size=(32,1,1), vals=(N, N, N), wait=True)
|
||||
print(f"{N*N*N*2/(et*1e9):2f} GFLOPS")
|
||||
|
||||
val = ((a@b).contiguous()-c).mean()
|
||||
print(val.item())
|
||||
for _ in range(5):
|
||||
with Context(DEBUG=2):
|
||||
ref = (a@b).realize()
|
||||
|
||||
print((ref-c).mean().item())
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user