merge tagless reshapes (#12474)

* merge tagless reshapes

* cleanup
This commit is contained in:
George Hotz
2025-10-07 13:57:58 +08:00
committed by GitHub
parent 7b48f3cc45
commit 514d2a0774
2 changed files with 10 additions and 4 deletions

View File

@@ -1,4 +1,5 @@
# include directory copied from https://github.com/HazyResearch/ThunderMittens
# https://hazyresearch.stanford.edu/blog/2024-11-28-tk-mlx
gemm = """
#include <metal_stdlib>
@@ -41,10 +42,9 @@ kernel void matmul_naive(GEMM_PARAMS_DEF(T)) {
instantiate_matmul_custom(float32, float);
"""
from tinygrad import Device, Tensor
from tinygrad import Device, Tensor, Context
if __name__ == "__main__":
# TODO: why isn't this type inferred?
device = Device["METAL"]
lib = device.compiler.compile(gemm)
prg = device.runtime("matmul_custom_float32", lib)
@@ -65,7 +65,10 @@ if __name__ == "__main__":
global_size=gsz, local_size=(32,1,1), vals=(N, N, N), wait=True)
print(f"{N*N*N*2/(et*1e9):2f} GFLOPS")
val = ((a@b).contiguous()-c).mean()
print(val.item())
for _ in range(5):
with Context(DEBUG=2):
ref = (a@b).realize()
print((ref-c).mean().item())