diff --git a/extra/gemm/halide_gemm.py b/extra/gemm/halide_gemm.py new file mode 100644 index 0000000000..acfd83d14e --- /dev/null +++ b/extra/gemm/halide_gemm.py @@ -0,0 +1,90 @@ +import numpy as np +import halide as hl +from tinygrad.helpers import Timing, getenv + +# HL_DEBUG_CODEGEN=1 +N = getenv("N", 1024) + +def gemm_pipeline(gpu=False): + # ---------------- Vars & Parameters ---------------- + i, j = hl.Var("i"), hl.Var("j") # output tile coordinates + + A = hl.InputBuffer(hl.Float(32), 2) # [M, K] + B = hl.InputBuffer(hl.Float(32), 2) # [K, N] + + A.dim(0).set_bounds(0, N) + A.dim(1).set_bounds(0, N) + B.dim(0).set_bounds(0, N) + B.dim(1).set_bounds(0, N) + + # ---------------- Definition ---------------- + + k = hl.RDom([(0, N)]) + + partial = hl.Func("partial") + partial[i, j] = 0.0 + partial[i, j] += A[i, k] * B[k, j] + + C = hl.Func("C") + C[i, j] = partial[i, j] + + if not gpu: + # ---------------- Schedule ---------------- + VEC = 16 + TILE_I = 64 + TILE_J = 64 + + io, jo, ii, ji = hl.Var("io"), hl.Var("jo"), hl.Var("ii"), hl.Var("ji") + C.update().tile(i, j, io, jo, ii, ji, TILE_I, TILE_J).fuse(io, jo, io).parallel(io).vectorize(ji, VEC) + else: + # ---------------- Schedule ---------------- + GRP_I = 8 # output tile size + GRP_J = 16 + + #partial.store_in(hl.MemoryType.Register) + #partial.update().unroll(k, 4) + + io, jo, ii, ji = hl.Var(), hl.Var(), hl.Var(), hl.Var() + C.gpu_tile(i, j, io, jo, ii, ji, GRP_I, GRP_J, hl.TailStrategy.RoundUp) + + return C, A, B + +if __name__ == "__main__": + pipe, A, B = gemm_pipeline(gpu=True) + + # NOTE: meteal does nothing + target = hl.get_host_target().with_feature(hl.TargetFeature.Metal) + + a_np = np.random.randn(N, N).astype(np.float32) + b_np = np.random.randn(N, N).astype(np.float32) + + # reverse order is correct! + a_hal = hl.Buffer(b_np) + b_hal = hl.Buffer(a_np) + A.set(a_hal) + B.set(b_hal) + + pipe.compile_to_lowered_stmt("/tmp/my_function.html", [A, B], hl.StmtOutputFormat.HTML, target=target) + #exit(0) + + c_hal = hl.Buffer(hl.Float(32), [N,N]) + with Timing("halide gemm "): + pipe.realize(c_hal, target) + c_hal.copy_to_host() + c_out = np.array(c_hal) + print(c_out) + + # tinygrad gets 60 ms with no BEAM, 20 ms with BEAM on CPU + with Timing("halide gemm "): + pipe.realize(c_hal, target) + c_hal.copy_to_host() + + # Check correctness + with Timing("numpy gemm "): + ref = a_np @ b_np + max_err = np.abs(ref - c_out).max() + print("Max absolute error:", max_err) + assert max_err < 1e-4, "GEMM result incorrect!" + + print("Pipeline ran on", target) + print("Success - GEMM Halide-Python output matches NumPy.")