One hot in tensor.py (#3093)

* onehot in Tensor.py

* one_hot tests

* works for all shapes, not just 1

* pylint

* not a static method

* moved around, num_classes mandatory

* pylint

* pylint

* space & moving

* formatting

* moved tests
This commit is contained in:
SnakeOnex
2024-01-12 19:31:18 +01:00
committed by GitHub
parent 7086d77db1
commit 025fbf4e80
2 changed files with 8 additions and 0 deletions

View File

@@ -1488,6 +1488,12 @@ class TestOps(unittest.TestCase):
helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)),
lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
def test_one_hot(self):
data = [1, 2, 4]
helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 6), lambda: Tensor(data).one_hot(6), forward_only=True)
data = [[[1, 2, 3], [0, 3, 5]], [[1, 2, 3], [0, 3, 5]]]
helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 8), lambda: Tensor(data).one_hot(8), forward_only=True)
if __name__ == '__main__':
np.random.seed(1337)
unittest.main(verbosity=2)