From 4008f7d4e85a8162e1d800bc7e4c31c550bcd6a9 Mon Sep 17 00:00:00 2001 From: chenyu Date: Mon, 2 Mar 2026 10:56:41 -0500 Subject: [PATCH] move Tensor.one_hot +1 to python (#15088) --- tinygrad/tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 244de0c731..e8f6e869b1 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -3265,7 +3265,7 @@ class Tensor(OpMixin): ``` """ if not dtypes.is_int(self.dtype): raise RuntimeError(f"expect integer dtype, getting {self.dtype=}") - if num_classes == -1: num_classes = int((self.max()+1).item()) + if num_classes == -1: num_classes = int(self.max().item())+1 return self[..., None]._one_hot_along_dim(num_classes).where(1, 0) def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tensor|None=None, dropout_p:float=0.0,