Expand Operator (#327)

* replace broadcasting with expand

* Tensor, not self

* remove broadcasting from mlops

* delete useless A operator

* expand, not repeat

* remove A op

* expand on gpu

* binary_op doesn't broadcast anymore

* expand is still total junk, but the tests should pass
This commit is contained in:
George Hotz
2022-06-12 12:31:48 -07:00
committed by GitHub
parent 5cf7649eda
commit dcbca4fdf1
8 changed files with 127 additions and 102 deletions

View File

@@ -168,6 +168,10 @@ class TestOps(unittest.TestCase):
def test_detach(self):
helper_test_op([(4,3,6,6)], lambda x: x.detach(), lambda x: x.detach(), forward_only=True)
def test_expand(self):
arg = (4,3,2,6)
helper_test_op([(4,3,1,6)], lambda x: x.expand(arg), lambda x: x.expand(shape=arg))
def test_simple_conv2d(self):
helper_test_op([(1,1,9,9), (1,1,3,3)],
lambda x,w: torch.nn.functional.conv2d(x,w).relu(),