mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
move Tensor.one_hot +1 to python (#15088)
This commit is contained in:
@@ -3265,7 +3265,7 @@ class Tensor(OpMixin):
|
||||
```
|
||||
"""
|
||||
if not dtypes.is_int(self.dtype): raise RuntimeError(f"expect integer dtype, getting {self.dtype=}")
|
||||
if num_classes == -1: num_classes = int((self.max()+1).item())
|
||||
if num_classes == -1: num_classes = int(self.max().item())+1
|
||||
return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
|
||||
|
||||
def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,
|
||||
|
||||
Reference in New Issue
Block a user