mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
clean up Tensor.dot (#7728)
more docs (similar to numpy) and removed many confusing `-min(n2, 2)`
This commit is contained in:
@@ -879,15 +879,15 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot)
|
||||
helper_test_op([(8,45,65), (65)], lambda x,y: x.matmul(y), Tensor.dot)
|
||||
helper_test_op([(65), (8,65,45)], lambda x,y: x.matmul(y), Tensor.dot)
|
||||
self.helper_test_exception([(4), (1,2)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
|
||||
self.helper_test_exception([(2,1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
|
||||
self.helper_test_exception([(1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
|
||||
self.helper_test_exception([(4), (1,2)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
|
||||
self.helper_test_exception([(2,1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
|
||||
self.helper_test_exception([(1), (4)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
|
||||
def test_dot(self):
|
||||
helper_test_op([(45,65), (65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5)
|
||||
helper_test_op([(8,45,65), (8,65,100)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-5)
|
||||
self.helper_test_exception([(2, 4), (1, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
|
||||
self.helper_test_exception([(2, 1), (4, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=(RuntimeError, AssertionError))
|
||||
with self.assertRaises(AssertionError):
|
||||
self.helper_test_exception([(2, 4), (1, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
|
||||
self.helper_test_exception([(2, 1), (4, 3)], lambda x, y: x.matmul(y), Tensor.dot, expected=RuntimeError)
|
||||
with self.assertRaises(RuntimeError):
|
||||
a = Tensor(3.14)
|
||||
a.matmul(a)
|
||||
def test_mulacc_with_zero_strides(self):
|
||||
@@ -954,7 +954,7 @@ class TestOps(unittest.TestCase):
|
||||
helper_test_op([(0), (0)], lambda x,y: x.matmul(y), Tensor.dot, atol=1e-7)
|
||||
def test_broadcastdot(self):
|
||||
helper_test_op([(10,45,65), (65,45)], lambda x,y: x @ y, Tensor.dot, atol=1e-4)
|
||||
with self.assertRaises(AssertionError):
|
||||
with self.assertRaises(RuntimeError):
|
||||
a = Tensor(3.14)
|
||||
b = Tensor.ones(3,3)
|
||||
a @ b
|
||||
|
||||
Reference in New Issue
Block a user