clean up Tensor.dot (#7728)

more docs (similar to numpy) and removed many confusing  `-min(n2, 2)`
This commit is contained in:
chenyu
2024-11-15 18:21:15 -05:00
committed by GitHub
parent 4338c450ac
commit 22da31b223
2 changed files with 25 additions and 16 deletions

View File

@@ -879,15 +879,15 @@ class TestOps(unittest.TestCase):
helper_test_op([(45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot)
helper_test_op([(8,45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot)
helper_test_op([(65), (8,65,45)], lambda x,y: x.matmul(y), Tensor.dot)
self.helper_test_exception([(4), (1,2)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
self.helper_test_exception([(2,1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
self.helper_test_exception([(1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
self.helper_test_exception([(4), (1,2)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
self.helper_test_exception([(2,1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
self.helper_test_exception([(1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
def test_dot(self):
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5)
helper_test_op([(8,45,65), (8,65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5)
self.helper_test_exception([(2, 4), (1, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
self.helper_test_exception([(2, 1), (4, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
with self.assertRaises(AssertionError):
self.helper_test_exception([(2, 4), (1, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
self.helper_test_exception([(2, 1), (4, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
with self.assertRaises(RuntimeError):
a = Tensor(3.14)
a.matmul(a)
def test_mulacc_with_zero_strides(self):
@@ -954,7 +954,7 @@ class TestOps(unittest.TestCase):
helper_test_op([(0), (0)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7)
def test_broadcastdot(self):
helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
with self.assertRaises(AssertionError):
with self.assertRaises(RuntimeError):
a = Tensor(3.14)
b = Tensor.ones(3,3)
a @ b