make maximum split the grad like torch when equal (#738)

* make maximum split grad

* added test for maximum split grad when equal

* minor expr simplification

* (2-eq)/2 only once

* update test bc one more sum output child stays
This commit is contained in:
worldwalker2000
2023-04-14 00:17:46 -07:00
committed by GitHub
parent 06ed958abd
commit 552a048a33
3 changed files with 9 additions and 5 deletions

View File

@@ -91,6 +91,7 @@ class TestOps(unittest.TestCase):
def test_maximum(self):
helper_test_op([(45,65), (45,65)], torch.maximum, Tensor.maximum)
helper_test_op(None, torch.maximum, Tensor.maximum, vals=[[1., 2., 3., 4.], [1., 2., 3., 4.]])
def test_minimum(self):
helper_test_op([(45,65), (45,65)], torch.minimum, Tensor.minimum)
def test_add(self):