RDNA3 fp16 assembly gemm 85 TFLOPS (#13990)

This commit is contained in:
qazal
2026-01-03 18:34:23 +09:00
committed by GitHub
parent 6242a9d151
commit bd55507ee4
4 changed files with 3128 additions and 4 deletions

View File

@@ -140,11 +140,11 @@ def hand_spec_kernel3():
return sink.sink(arg=KernelInfo(opts_to_apply=())).simplify()
def test_matmul(sink:UOp, N=N):
def test_matmul(sink:UOp, dtype=dtypes.float32, N=N):
rng = np.random.default_rng()
a = Tensor(rng.random((N, N), dtype=np.float32)-0.5)
b = Tensor(rng.random((N, N), dtype=np.float32)-0.5)
hc = Tensor.empty(N, N)
a = Tensor(rng.random((N, N), dtype=np.float32)-0.5, dtype=dtype)
b = Tensor(rng.random((N, N), dtype=np.float32)-0.5, dtype=dtype)
hc = Tensor.empty(N, N, dtype=dtype)
Tensor.realize(a, b, hc)
ei = ExecItem(sink, [t.uop.buffer for t in [hc, a, b]], prg=get_runner(Device.DEFAULT, sink))