diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 244de0c731..e8f6e869b1 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3265,7 +3265,7 @@ class Tensor(OpMixin): ``` """ if not dtypes.is_int(self.dtype): raise RuntimeError(f"expect integer dtype, getting {self.dtype=}") - if num_classes == -1: num_classes = int((self.max()+1).item()) + if num_classes == -1: num_classes = int(self.max().item())+1 return self[..., None]._one_hot_along_dim(num_classes).where(1, 0) def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,