diff --git a/extra/gemm/asm/cdna/test.py b/extra/gemm/asm/cdna/test.py index d19f911a52..c0c4c78ef9 100644 --- a/extra/gemm/asm/cdna/test.py +++ b/extra/gemm/asm/cdna/test.py @@ -2,10 +2,8 @@ # VIZ=2 to profile import pathlib from tinygrad import Tensor, Device, dtypes, Context -from tinygrad.engine.realize import ExecItem, CompiledRunner -from tinygrad.renderer import ProgramSpec -from tinygrad.uop.ops import track_rewrites, UOp -from tinygrad.helpers import TracingKey, getenv +from tinygrad.uop.ops import UOp, Ops, KernelInfo +from tinygrad.helpers import getenv fp = pathlib.Path(__file__).parent/"gemm.s" @@ -23,8 +21,8 @@ import torch torch.manual_seed(0) A = (torch.randn(N, N, dtype=torch.float32, device="cpu") / scale).to(torch.bfloat16).contiguous() B = (torch.randn(N, N, dtype=torch.float32, device="cpu") / scale).to(torch.bfloat16).contiguous() -Bt = B.t().contiguous() # transpose B for the baseline gemm -C_torch = A@Bt +Bt = B.t().contiguous() # transpose B for the asm gemm +C_torch = A@B # ** copy buffers to AMD @@ -33,31 +31,33 @@ C_torch = A@Bt def from_torch(t:torch.Tensor) -> Tensor: return Tensor.from_blob(t.data_ptr(), t.shape, dtype=dtypes.bfloat16, device="cpu").to(Device.DEFAULT).realize() -C_tiny = Tensor.matmul(from_torch(A), from_torch(Bt), dtype=dtypes.float32).cast(dtypes.bfloat16) +C_tiny = from_torch(A) @ from_torch(B) C_asm = Tensor.empty_like(C_tiny) -C_asm.uop.buffer.allocate() + +# ** assembly custom kernel + +def custom_asm_gemm(C:UOp, A:UOp, B:UOp) -> UOp: + lidx = UOp.special(THREADS_PER_WG, "lidx0") + gidx = UOp.special(NUM_WG, "gidx0") + + src = (pathlib.Path(__file__).parent/"template.s").read_text().replace("INSTRUCTIONS", fp.read_text()) + + sz = UOp.variable("SZ", 256, 8192) + wg = UOp.variable("WG", 1, 1024) + + sink = UOp.sink(C.base, A.base, B.base, sz, wg, lidx, gidx, arg=KernelInfo(name="gemm")) + return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src))) + +C_asm = Tensor.custom_kernel(C_asm, from_torch(A), from_torch(Bt), fxn=custom_asm_gemm)[0] # ** run gemms -# baseline tinygrad -sched = C_tiny.schedule() -assert len(sched) == 1 -eis:list[ExecItem] = [sched[-1].lower()] -ast = sched[-1].ast - -# assembly gemm -@track_rewrites(name=lambda ret: TracingKey(ret.name, (ret.function_name,), ret)) -def get_asm_prg() -> ProgramSpec: - src = (pathlib.Path(__file__).parent/"template.s").read_text().replace("INSTRUCTIONS", fp.read_text()) - lib = Device[Device.DEFAULT].compiler.compile(src) - return ProgramSpec("gemm", src, Device.DEFAULT, ast, lib=lib, global_size=[NUM_WG, 1, 1], local_size=[THREADS_PER_WG, 1, 1], - globals=[0, 1, 2], vars=[UOp.variable("SZ", 256, 8192), UOp.variable("NUM_WG", 1, 1024)]) -eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(A).uop.buffer, from_torch(B).uop.buffer], fixedvars={"SZ":N, "NUM_WG":NUM_WG}, - prg=CompiledRunner(get_asm_prg()))) +sched = Tensor.schedule(C_tiny, C_asm) +eis = [si.lower() for si in sched] with Context(DEBUG=2): for ei in eis: - et = ei.run(wait=True) + et = ei.run({"SZ":N, "WG":NUM_WG}, wait=True) print(f"{(N*N*N*2 / et)*1e-12:.2f} REAL TFLOPS") # ** correctness