_one_hot_along_dim input needs to be int (#9179)

* _one_hot_along_dim input needs to be int

Indexing and one-hot encoding compare against arange, so a non-int dtype is likely a bug.
This commit is contained in:
chenyu
2025-02-20 09:00:43 -05:00
committed by GitHub
parent bf36967883
commit 1692087db5
3 changed files with 7 additions and 4 deletions

View File

@@ -51,7 +51,7 @@ class Transformer:
maxlen_eye = Tensor.eye(x.shape[1])
maxlen_eye = maxlen_eye.unsqueeze(0).expand([bs, *maxlen_eye.shape])
onehot_feat = x.one_hot(self.syms)
onehot_feat = x.int().one_hot(self.syms)
onehot = maxlen_eye.cat(onehot_feat, dim=2).flatten(end_dim=1)

View File

@@ -509,7 +509,7 @@ def get_onnx_ops():
for i in range(-len(sizes), 0):
reshape, index = [1] * X.ndim, indexes[i]
reshape[i] = expand[i] = sizes[i]
low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor().int(), index.ceil().int(), index - index.floor())]
X = X.gather(i, low).lerp(X.gather(i, high), perc)
if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented")
return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X
@@ -579,6 +579,7 @@ def get_onnx_ops():
def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1):
# Scalar or Rank 1 tensor containing exactly one element
depth = int(depth[0] if isinstance(depth, list) else depth)
indices = indices.int()
indices = (indices < 0).where(indices+depth, indices)
return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0])

View File

@@ -1117,7 +1117,7 @@ class Tensor(SimpleMathTrait):
case list() | tuple() | Tensor():
if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False)
if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
index = (index.to(self.device) < 0).where(size, 0) + index # treat negative index values
index = (index.to(self.device) < 0).where(index+size, index) # treat negative index values
case int() | UOp(): # sint
if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}")
boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
@@ -2453,7 +2453,7 @@ class Tensor(SimpleMathTrait):
reshape[i] = expand[i] = size[i]
if mode == "linear":
index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor().int(), index.ceil().int(), index - index.floor())]
x = x.gather(i, low).lerp(x.gather(i, high), perc)
else:
index = (scale*(arr+0.5) if mode=="nearest-exact" else scale*arr).cast(dtypes.int32).reshape(reshape).expand(expand)
@@ -3558,6 +3558,7 @@ class Tensor(SimpleMathTrait):
# helper function commonly used for indexing
def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1):
if not dtypes.is_int(self.dtype): raise RuntimeError(f"_one_hot_along_dim expects int index tensor, getting {self.dtype}")
offset = self.ndim - self._resolve_dim(dim) - 1
return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)
@@ -3572,6 +3573,7 @@ class Tensor(SimpleMathTrait):
print(t.one_hot(5).numpy())
```
"""
if not dtypes.is_int(self.dtype): raise RuntimeError(f"expect integer dtype, getting {self.dtype=}")
if num_classes == -1: num_classes = (self.max()+1).item()
return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)