mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-01-09 23:18:04 -05:00
Uppercase N and M (standalone syntax change) (#9647)
This commit is contained in:
@@ -25,8 +25,8 @@ def helper_realized_ast(r:Union[Tensor, list[Tensor]]) -> tuple[UOp, list[Buffer
|
||||
bufs = [Buffer((x).device, x.size, x.dtype).allocate() if i < len(s[-1].ast.src) else x for i,x in enumerate(s[-1].bufs)]
|
||||
return s[-1].ast, bufs
|
||||
|
||||
def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0):
|
||||
a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
|
||||
def helper_tc_allclose(N:int, M:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0):
|
||||
a, b = Tensor.rand(M, k, dtype=dtype_in), Tensor.rand(k, N, dtype=dtype_in)
|
||||
np_a, np_b = a.numpy(), b.numpy()
|
||||
r = a.matmul(b, dtype=dtype_out)
|
||||
sched = r.schedule()
|
||||
@@ -44,9 +44,9 @@ def helper_tc_allclose(n:int, m:int, k:int, dtype_in:DType, dtype_out:DType, axi
|
||||
else: tc_atol, tc_rtol = 5e-3, 1e-4
|
||||
np.testing.assert_allclose(np_c, out, atol=tc_atol, rtol=tc_rtol)
|
||||
|
||||
def helper_tc_ensure_uops_and_opts_count(n: int, m:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0,
|
||||
def helper_tc_ensure_uops_and_opts_count(N: int, M:int, k:int, dtype_in:DType, dtype_out:DType, axis:int=0, tc_select:int=-1, tc_opt:int=0,
|
||||
ensure_triggered:bool=True):
|
||||
a, b = Tensor.rand(m, k, dtype=dtype_in), Tensor.rand(k, n, dtype=dtype_in)
|
||||
a, b = Tensor.rand(M, k, dtype=dtype_in), Tensor.rand(k, N, dtype=dtype_in)
|
||||
r = a.matmul(b, dtype=dtype_out)
|
||||
sched = r.schedule()
|
||||
realized_ast = sched[-1].ast
|
||||
|
||||
Reference in New Issue
Block a user