move Tensor.one_hot +1 to python (#15088)

This commit is contained in:
chenyu
2026-03-02 10:56:41 -05:00
committed by GitHub
parent dafbe9733a
commit 4008f7d4e8

View File

@@ -3265,7 +3265,7 @@ class Tensor(OpMixin):
```
"""
if not dtypes.is_int(self.dtype): raise RuntimeError(f"expect integer dtype, getting {self.dtype=}")
if num_classes == -1: num_classes = int((self.max()+1).item())
if num_classes == -1: num_classes = int(self.max().item())+1
return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,