diff --git a/test/test_ops.py b/test/test_ops.py
index 8b615ae038..2d423bb1f4 100644
--- a/test/test_ops.py
+++ b/test/test_ops.py
@@ -1488,6 +1488,12 @@ class TestOps(unittest.TestCase):
     helper_test_op([(32,10), (32,10)], lambda x,y: torch.nn.functional.binary_cross_entropy(x.sigmoid(),torch.clip(y,0,1)),
                                        lambda x,y: x.binary_crossentropy_logits(y.clip(0,1)))
 
+  def test_one_hot(self):
+    data = [1, 2, 4]
+    helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 6), lambda: Tensor(data).one_hot(6), forward_only=True)
+    data = [[[1, 2, 3], [0, 3, 5]], [[1, 2, 3], [0, 3, 5]]]
+    helper_test_op([], lambda: torch.nn.functional.one_hot(torch.tensor(data), 8), lambda: Tensor(data).one_hot(8), forward_only=True)
+
 if __name__ == '__main__':
   np.random.seed(1337)
   unittest.main(verbosity=2)
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 89f85ac4c0..3bbac7f813 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -893,6 +893,8 @@ class Tensor:
     if not Tensor.training or p == 0: return self
     return self * (Tensor.rand(*self.shape, requires_grad=False, device=self.device) >= p) * (1/(1.0 - p))
 
+  def one_hot(self, num_classes:int, **kwargs) -> Tensor: return Tensor.where(self[..., None] == Tensor.arange(num_classes), 1, 0, **kwargs)
+
   def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Optional[Tensor]=None, dropout_p:float=0.0, is_causal:bool=False) -> Tensor: # noqa: E501
     # NOTE: it works if key, value have symbolic shape
     assert all_int(self.shape), f"does not support symbolic shape {self.shape}"
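
Note on the approach (not part of the diff): the new `one_hot` is a broadcasting trick. `self[..., None]` appends a trailing axis of size 1, comparing it against `Tensor.arange(num_classes)` broadcasts to shape `(*self.shape, num_classes)`, and the `where(cond, 1, 0)` step turns the boolean mask into 0/1 values; since the result carries no useful gradient, the tests run with `forward_only=True`. Below is a minimal sketch of the same trick in plain NumPy, for illustration only — `one_hot_sketch` is a hypothetical helper, not tinygrad or PyTorch code.

```python
import numpy as np

def one_hot_sketch(indices: np.ndarray, num_classes: int) -> np.ndarray:
  # indices[..., None] has shape (*indices.shape, 1); np.arange(num_classes) has
  # shape (num_classes,), so the == comparison broadcasts to
  # (*indices.shape, num_classes). np.where maps the booleans to 1/0,
  # mirroring the where(cond, 1, 0) step in the diff above.
  return np.where(indices[..., None] == np.arange(num_classes), 1, 0)

print(one_hot_sketch(np.array([1, 2, 4]), 6))
# [[0 1 0 0 0 0]
#  [0 0 1 0 0 0]
#  [0 0 0 0 1 0]]
```

The `[1, 2, 4]` / 6-class case matches the first test above; because the comparison broadcasts over any leading axes, the same expression handles the nested 3-D input in the second test without special-casing.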