Uppercase N and M (standalone syntax change) (#9647)

This commit is contained in:
Ignacio Sica
2025-03-31 18:45:30 +08:00
committed by GitHub
parent aca0f1befb
commit baa67fd124

View File

@@ -25,8 +25,8 @@ def helper_realized_ast(r:Union[Tensor, list[Tensor]]) -> tuple[UOp, list[Buffer
bufs = [Buffer((x).device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)]
return s[-1].ast, bufs
def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0):
a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
def helper_tc_allclose(N:int, M:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0):
a, b = Tensor.rand(M, k, dtype=dtype_in), Tensor.rand(k, N, dtype=dtype_in)
np_a, np_b = a.numpy(), b.numpy()
r = a.matmul(b, dtype=dtype_out)
sched = r.schedule()
@@ -44,9 +44,9 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi
else: tc_atol, tc_rtol = 5e-3, 1e-4
np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)
def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0,
def helper_tc_ensure_uops_and_opts_count(N: int, M:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0,
ensure_triggered:bool=True):
a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
a, b = Tensor.rand(M, k, dtype=dtype_in), Tensor.rand(k, N, dtype=dtype_in)
r = a.matmul(b, dtype=dtype_out)
sched = r.schedule()
realized_ast = sched[-1].ast