wmma: enable METAL half tensor cores and clean up cstyle (#3095)

* wmma: enable METAL half tensor cores and clean up cstyle

* revert simple_matmul rand changes and break line in tensor

* added metal fp16->fp32 tensor core

This commit is contained in:
Francis Lam
2024-01-12 13:25:28 -08:00
committed by GitHub
parent f96fc6e9d4
commit ddbdb52f77
7 changed files with 37 additions and 42 deletions

View File

@@ -87,10 +87,7 @@ class TestLinearizer(unittest.TestCase):
if tc.arch is not None and tc.arch != os.uname().machine: continue
a, b = Tensor.rand(tc.dims[0], tc.dims[2], dtype=tc.dtype_in), Tensor.rand(tc.dims[2], tc.dims[1], dtype=tc.dtype_in)
np_a, np_b = a.numpy(), b.numpy()
-if tc.dtype_out != tc.dtype_in:
-r = (a.reshape(tc.dims[0], 1, tc.dims[2]) * b.permute(1,0).reshape(1, tc.dims[1], tc.dims[2])).cast(tc.dtype_out).sum(axis=2)
-else:
-r = a @ b
+r = a.matmul(b, acc_dtype=tc.dtype_out)
realized_ast, _ = helper_realized_ast(r)
k = Linearizer(realized_ast)
k.apply_tensor_cores(1)