mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
default num_classes value for one_hot (#6182)
* num_classes=-1 If num_classes set to -1, the number of classes will be inferred as one greater than the largest class value in the input tensor. * num_classes desc comment to explain num_classes default and what that means. * replacing ' with `
This commit is contained in:
@@ -2033,9 +2033,13 @@ class TestOps(unittest.TestCase):
|
||||
data = [1, 2, 4]
|
||||
helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 6).type(torch.int32),
|
||||
lambda: Tensor(data).one_hot(6), forward_only=True)
|
||||
helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data)).type(torch.int32),
|
||||
lambda: Tensor(data).one_hot(), forward_only=True)
|
||||
data = [[[1, 2, 3], [0, 3, 5]], [[1, 2, 3], [0, 3, 5]]]
|
||||
helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 8).type(torch.int32),
|
||||
lambda: Tensor(data).one_hot(8), forward_only=True)
|
||||
helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data)).type(torch.int32),
|
||||
lambda: Tensor(data).one_hot(), forward_only=True)
|
||||
|
||||
def test_masked_fill(self):
|
||||
helper_test_op([(32,10)], lambda x: x.masked_fill((x>0.1).detach(), -math.inf))
|
||||
|
||||
Reference in New Issue
Block a user