Mirror of https://github.com/tinygrad/tinygrad.git
synced 2026-01-08 22:48:25 -05:00
mi350x gemm: use Tensor.custom_kernel in asm test (#13969)
* mi350x gemm: use Tensor.custom_kernel in asm test
* A @ B for baseline
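In outline: the old test compiled the assembly itself, wrapped it in a ProgramSpec/CompiledRunner, and appended a hand-made ExecItem to the run list; the new test returns an Ops.PROGRAM UOp from a plain function and lets Tensor.custom_kernel wire it into scheduling. A condensed before/after, abridged from the diff below (the ... elisions and the bufs name are mine):

    # before: hand-compile and splice an ExecItem into the run list
    lib = Device[Device.DEFAULT].compiler.compile(src)
    prg = ProgramSpec("gemm", src, Device.DEFAULT, ast, lib=lib, ...)
    eis.append(ExecItem(ast, bufs, fixedvars={"SZ": N, "NUM_WG": NUM_WG}, prg=CompiledRunner(prg)))

    # after: describe the kernel as a UOp graph; Tensor.custom_kernel schedules it
    def custom_asm_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
      ...  # SINK over buffers/vars/special indices, wrapped in an Ops.PROGRAM (full body in the diff)
    C_asm = Tensor.custom_kernel(C_asm, from_torch(A), from_torch(Bt), fxn=custom_asm_gemm)[0]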
@@ -2,10 +2,8 @@
 # VIZ=2 to profile
 import pathlib
 from tinygrad import Tensor, Device, dtypes, Context
-from tinygrad.engine.realize import ExecItem, CompiledRunner
-from tinygrad.renderer import ProgramSpec
-from tinygrad.uop.ops import track_rewrites, UOp
-from tinygrad.helpers import TracingKey, getenv
+from tinygrad.uop.ops import UOp, Ops, KernelInfo
+from tinygrad.helpers import getenv
 
 fp = pathlib.Path(__file__).parent/"gemm.s"
 
@@ -23,8 +21,8 @@ import torch
 torch.manual_seed(0)
 A = (torch.randn(N, N, dtype=torch.float32, device="cpu") / scale).to(torch.bfloat16).contiguous()
 B = (torch.randn(N, N, dtype=torch.float32, device="cpu") / scale).to(torch.bfloat16).contiguous()
-Bt = B.t().contiguous() # transpose B for the baseline gemm
-C_torch = A@Bt
+Bt = B.t().contiguous() # transpose B for the asm gemm
+C_torch = A@B
 
 # ** copy buffers to AMD
 
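The torch reference now computes A@B directly; Bt (B transposed) is kept only as the operand the hand-written gemm consumes, presumably because the asm kernel reads its second input in transposed layout. A standalone torch check of that identity (illustration only, not part of the commit):

    import torch
    A = torch.randn(8, 8); B = torch.randn(8, 8)
    Bt = B.t().contiguous()                   # the layout handed to the asm kernel
    assert torch.allclose(A @ B, A @ Bt.t())  # Bt.t() recovers B, so results match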
@@ -33,31 +31,33 @@ C_torch = A@Bt
 def from_torch(t:torch.Tensor) -> Tensor:
   return Tensor.from_blob(t.data_ptr(), t.shape, dtype=dtypes.bfloat16, device="cpu").to(Device.DEFAULT).realize()
 
-C_tiny = Tensor.matmul(from_torch(A), from_torch(Bt), dtype=dtypes.float32).cast(dtypes.bfloat16)
+C_tiny = from_torch(A) @ from_torch(B)
 C_asm = Tensor.empty_like(C_tiny)
-C_asm.uop.buffer.allocate()
+
+# ** assembly custom kernel
+
+def custom_asm_gemm(C:UOp, A:UOp, B:UOp) -> UOp:
+  lidx = UOp.special(THREADS_PER_WG, "lidx0")
+  gidx = UOp.special(NUM_WG, "gidx0")
+
+  src = (pathlib.Path(__file__).parent/"template.s").read_text().replace("INSTRUCTIONS", fp.read_text())
+
+  sz = UOp.variable("SZ", 256, 8192)
+  wg = UOp.variable("WG", 1, 1024)
+
+  sink = UOp.sink(C.base, A.base, B.base, sz, wg, lidx, gidx, arg=KernelInfo(name="gemm"))
+  return UOp(Ops.PROGRAM, src=(sink, UOp(Ops.DEVICE, arg=Device.DEFAULT), UOp(Ops.LINEAR, src=(*sink.src, sink)), UOp(Ops.SOURCE, arg=src)))
+
+C_asm = Tensor.custom_kernel(C_asm, from_torch(A), from_torch(Bt), fxn=custom_asm_gemm)[0]
 
 # ** run gemms
 
-# baseline tinygrad
-sched = C_tiny.schedule()
-assert len(sched) == 1
-eis:list[ExecItem] = [sched[-1].lower()]
-ast = sched[-1].ast
-
-# assembly gemm
-@track_rewrites(name=lambda ret: TracingKey(ret.name, (ret.function_name,), ret))
-def get_asm_prg() -> ProgramSpec:
-  src = (pathlib.Path(__file__).parent/"template.s").read_text().replace("INSTRUCTIONS", fp.read_text())
-  lib = Device[Device.DEFAULT].compiler.compile(src)
-  return ProgramSpec("gemm", src, Device.DEFAULT, ast, lib=lib, global_size=[NUM_WG, 1, 1], local_size=[THREADS_PER_WG, 1, 1],
-                     globals=[0, 1, 2], vars=[UOp.variable("SZ", 256, 8192), UOp.variable("NUM_WG", 1, 1024)])
-eis.append(ExecItem(ast, [C_asm.uop.buffer, from_torch(A).uop.buffer, from_torch(B).uop.buffer], fixedvars={"SZ":N, "NUM_WG":NUM_WG},
-           prg=CompiledRunner(get_asm_prg())))
+sched = Tensor.schedule(C_tiny, C_asm)
+eis = [si.lower() for si in sched]
 
 with Context(DEBUG=2):
   for ei in eis:
-    et = ei.run(wait=True)
+    et = ei.run({"SZ":N, "WG":NUM_WG}, wait=True)
     print(f"{(N*N*N*2 / et)*1e-12:.2f} REAL TFLOPS")
 
 # ** correctness
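One behavioral change buried in the run loop: SZ and the workgroup count used to be frozen into the ExecItem via fixedvars, while the new path leaves them as UOp.variables that every ei.run call binds. A minimal sketch of the variable half (names and bounds copied from the diff; the run call is shown in a comment since it needs a live ExecItem):

    from tinygrad.uop.ops import UOp

    sz = UOp.variable("SZ", 256, 8192)  # symbolic matrix size: name, min, max
    wg = UOp.variable("WG", 1, 1024)    # symbolic workgroup count
    # both are passed into UOp.sink(...) above, so they become launch parameters;
    # a concrete launch then binds them: et = ei.run({"SZ": N, "WG": NUM_WG}, wait=True)
    print(sz, wg)  # inspect the variable UOps

This also lets one compiled kernel be re-launched at different sizes without rebuilding the program, which the old fixedvars construction did not allow.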