diff --git a/extra/models/transformer.py b/extra/models/transformer.py
index 3e89846376..4753e50731 100644
--- a/extra/models/transformer.py
+++ b/extra/models/transformer.py
@@ -51,7 +51,7 @@ class Transformer:
     maxlen_eye = Tensor.eye(x.shape[1])
     maxlen_eye = maxlen_eye.unsqueeze(0).expand([bs, *maxlen_eye.shape])
 
-    onehot_feat = x.one_hot(self.syms)
+    onehot_feat = x.int().one_hot(self.syms)
 
     onehot = maxlen_eye.cat(onehot_feat, dim=2).flatten(end_dim=1)
 
diff --git a/extra/onnx.py b/extra/onnx.py
index 8c3938f899..f7020cf894 100644
--- a/extra/onnx.py
+++ b/extra/onnx.py
@@ -509,7 +509,7 @@ def get_onnx_ops():
       for i in range(-len(sizes), 0):
         reshape, index = [1] * X.ndim, indexes[i]
         reshape[i] = expand[i] = sizes[i]
-        low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
+        low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor().int(), index.ceil().int(), index - index.floor())]
         X = X.gather(i, low).lerp(X.gather(i, high), perc)
     if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented")
     return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X
@@ -579,6 +579,7 @@ def get_onnx_ops():
   def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1):
     # Scalar or Rank 1 tensor containing exactly one element
     depth = int(depth[0] if isinstance(depth, list) else depth)
+    indices = indices.int()
    indices = (indices < 0).where(indices+depth, indices)
    return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0])
 
diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py
index 7d8a6ee85d..303a747703 100644
--- a/tinygrad/tensor.py
+++ b/tinygrad/tensor.py
@@ -1117,7 +1117,7 @@ class Tensor(SimpleMathTrait):
       case list() | tuple() | Tensor():
         if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False)
         if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
-        index = (index.to(self.device) < 0).where(size, 0) + index # treat negative index values
+        index = (index.to(self.device) < 0).where(index+size, index) # treat negative index values
       case int() | UOp(): # sint
         if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}")
         boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
@@ -2453,7 +2453,7 @@ class Tensor(SimpleMathTrait):
       reshape[i] = expand[i] = size[i]
       if mode == "linear":
         index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
-        low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
+        low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor().int(), index.ceil().int(), index - index.floor())]
         x = x.gather(i, low).lerp(x.gather(i, high), perc)
       else:
         index = (scale*(arr+0.5) if mode=="nearest-exact" else scale*arr).cast(dtypes.int32).reshape(reshape).expand(expand)
@@ -3558,6 +3558,7 @@ class Tensor(SimpleMathTrait):
 
   # helper function commonly used for indexing
   def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1):
+    if not dtypes.is_int(self.dtype): raise RuntimeError(f"_one_hot_along_dim expects int index tensor, getting {self.dtype}")
     offset = self.ndim - self._resolve_dim(dim) - 1
     return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)
@@ -3572,6 +3573,7 @@ class Tensor(SimpleMathTrait):
     print(t.one_hot(5).numpy())
     ```
     """
+    if not dtypes.is_int(self.dtype): raise RuntimeError(f"expect integer dtype, getting {self.dtype=}")
     if num_classes == -1: num_classes = (self.max()+1).item()
     return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)