mirror of https://github.com/tinygrad/tinygrad.git
_one_hot_along_dim input needs to be int (#9179)
* _one_hot_along_dim input needs to be int: both indexing and one_hot compare the input against arange, so a non-int dtype is likely a bug
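For context, a small sketch (my own illustration, not part of the commit) of why a non-int input is likely a bug: the one-hot is built by comparing the index tensor against `Tensor.arange(num_classes)`, and a fractional value compares unequal everywhere, yielding a silent all-zero row instead of an error.

```python
from tinygrad import Tensor

# int indices: each value matches exactly one entry of arange(num_classes)
print(Tensor([0, 2]).one_hot(3).numpy())            # [[1 0 0] [0 0 1]]

# a fractional index equals no entry of arange, so the old behavior was a
# silent all-zero row; after this change one_hot raises RuntimeError instead
print((Tensor([1.5]) == Tensor.arange(3)).numpy())  # [False False False]
```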
@@ -51,7 +51,7 @@ class Transformer:
     maxlen_eye = Tensor.eye(x.shape[1])
     maxlen_eye = maxlen_eye.unsqueeze(0).expand([bs, *maxlen_eye.shape])
-    onehot_feat = x.one_hot(self.syms)
+    onehot_feat = x.int().one_hot(self.syms)
     onehot = maxlen_eye.cat(onehot_feat, dim=2).flatten(end_dim=1)
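For reference, a minimal sketch of the pattern this hunk fixes. The shapes, `bs`, and `syms=10` are made-up values, and `x` stands in for the model's feature tensor, which often arrives as float:

```python
from tinygrad import Tensor

bs, syms = 2, 10
x = Tensor([[1, 3, 5, 7], [2, 4, 6, 8]]).float()  # (bs, seqlen), float in practice

maxlen_eye = Tensor.eye(x.shape[1])                                # (seqlen, seqlen)
maxlen_eye = maxlen_eye.unsqueeze(0).expand([bs, *maxlen_eye.shape])
onehot_feat = x.int().one_hot(syms)                                # cast before one_hot
onehot = maxlen_eye.cat(onehot_feat, dim=2).flatten(end_dim=1)
print(onehot.shape)  # (8, 14): (bs*seqlen, seqlen+syms)
```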
@@ -509,7 +509,7 @@ def get_onnx_ops():
     for i in range(-len(sizes), 0):
       reshape, index = [1] * X.ndim, indexes[i]
       reshape[i] = expand[i] = sizes[i]
-      low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
+      low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor().int(), index.ceil().int(), index - index.floor())]
      X = X.gather(i, low).lerp(X.gather(i, high), perc)
     if mode == "cubic": raise NotImplementedError("cubic interpolation is not implemented")
     return X.permute(*[perm.index(i) for i in range(len(perm))]) if perm else X
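This is ONNX Resize's linear path: `low` and `high` feed `gather`, but `floor()` and `ceil()` keep the float dtype, hence the explicit `.int()`. A small sketch of the invariant behind the cast, with hypothetical values:

```python
from tinygrad import Tensor, dtypes

index = Tensor([0.3, 1.7])
assert index.floor().dtype == dtypes.float32     # floor/ceil do not change dtype
low, high = index.floor().int(), index.ceil().int()
perc = index - index.floor()

x = Tensor([10.0, 20.0, 30.0])
print(x.gather(0, low).lerp(x.gather(0, high), perc).numpy())  # [13. 27.]
```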
@@ -579,6 +579,7 @@ def get_onnx_ops():
   def OneHot(indices:Tensor, depth:float|int|list, values:Tensor, axis:int=-1):
     # Scalar or Rank 1 tensor containing exactly one element
     depth = int(depth[0] if isinstance(depth, list) else depth)
+    indices = indices.int()
     indices = (indices < 0).where(indices+depth, indices)
     return indices[:, None]._one_hot_along_dim(depth, dim=axis).where(values[1], values[0])
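A minimal sketch of the ONNX OneHot semantics this implements, restricted to `axis=-1` and rebuilt on public ops instead of the private helper, so it is an illustration rather than the real op:

```python
from tinygrad import Tensor

def onehot_sketch(indices: Tensor, depth: int, values: Tensor):
  idx = indices.int()                                # the fix: force int indices
  idx = (idx < 0).where(idx + depth, idx)            # ONNX allows negative indices
  hot = idx.unsqueeze(-1) == Tensor.arange(depth)    # bool one-hot along last axis
  return hot.where(values[1], values[0])             # values = [off_value, on_value]

print(onehot_sketch(Tensor([0, -1]), 3, Tensor([0.0, 9.0])).numpy())
# [[9. 0. 0.]
#  [0. 0. 9.]]
```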
@@ -1117,7 +1117,7 @@ class Tensor(SimpleMathTrait):
       case list() | tuple() | Tensor():
         if not isinstance(index, Tensor): index = Tensor(index, self.device, requires_grad=False)
         if not dtypes.is_int(index.dtype): raise IndexError(f"index dtype {index.dtype} is not supported")
-        index = (index.to(self.device) < 0).where(size, 0) + index # treat negative index values
+        index = (index.to(self.device) < 0).where(index+size, index) # treat negative index values
       case int() | UOp(): # sint
         if index >= size or index < -size: raise IndexError(f"{index=} is out of bounds with {size=}")
         boundary = [index, index+1] if index >= 0 else [index+size, index+size+1]
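Both forms normalize a negative index to `index + size`; the new one selects the value directly instead of adding a conditional offset. A quick check of the equivalence with hypothetical values:

```python
from tinygrad import Tensor

index, size = Tensor([1, -1, -3]), 4
old = (index < 0).where(size, 0) + index        # offset-then-add form
new = (index < 0).where(index + size, index)    # direct select form
print(old.numpy(), new.numpy())                 # [1 3 1] [1 3 1]
```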
@@ -2453,7 +2453,7 @@ class Tensor(SimpleMathTrait):
       reshape[i] = expand[i] = size[i]
       if mode == "linear":
         index = (scale*arr if align_corners else (scale*(arr+0.5))-0.5).clip(0, self.shape[i]-1)
-        low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor(), index.ceil(), index - index.floor())]
+        low, high, perc = [y.reshape(reshape).expand(expand) for y in (index.floor().int(), index.ceil().int(), index - index.floor())]
         x = x.gather(i, low).lerp(x.gather(i, high), perc)
       else:
         index = (scale*(arr+0.5) if mode=="nearest-exact" else scale*arr).cast(dtypes.int32).reshape(reshape).expand(expand)
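Same fix as the Resize hunk, applied to `Tensor.interpolate`'s linear path. A quick exercise of that path through the public API; the input and size are hypothetical, and the expected output is my hand calculation:

```python
from tinygrad import Tensor

x = Tensor([[10.0, 20.0, 30.0]])                   # (1, 3)
# linear mode builds float sample positions, then gathers at floor/ceil;
# those gather indices are what now get the explicit .int() cast
print(x.interpolate((5,), mode="linear").numpy())  # [[10. 14. 20. 26. 30.]]
```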
@@ -3558,6 +3558,7 @@ class Tensor(SimpleMathTrait):
 
   # helper function commonly used for indexing
   def _one_hot_along_dim(self:Tensor, num_classes:sint, dim:int=-1):
+    if not dtypes.is_int(self.dtype): raise RuntimeError(f"_one_hot_along_dim expects int index tensor, getting {self.dtype}")
     offset = self.ndim - self._resolve_dim(dim) - 1
     return self == Tensor.arange(num_classes, device=self.device, requires_grad=False).reshape((num_classes,) + (1,) * offset)
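The added guard protects the comparison the helper is built on. A sketch of the same computation reconstructed with public ops (variable names are mine; `dim` is given already resolved to skip `_resolve_dim`):

```python
from tinygrad import Tensor

idx = Tensor([[0], [2]])        # int index tensor, shape (2, 1)
num_classes, dim = 3, 1         # one-hot along dim 1
offset = idx.ndim - dim - 1     # number of dims to the right of `dim` (here 0)
# a float idx would compare unequal for fractional values, hence the new guard
hot = idx == Tensor.arange(num_classes).reshape((num_classes,) + (1,) * offset)
print(hot.numpy())              # [[ True False False]
                                #  [False False  True]]
```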
@@ -3572,6 +3573,7 @@ class Tensor(SimpleMathTrait):
     print(t.one_hot(5).numpy())
     ```
     """
+    if not dtypes.is_int(self.dtype): raise RuntimeError(f"expect integer dtype, getting {self.dtype=}")
     if num_classes == -1: num_classes = (self.max()+1).item()
     return self[..., None]._one_hot_along_dim(num_classes).where(1, 0)
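With the guard in place, the public `one_hot` fails fast instead of returning zeros. A quick before/after sketch:

```python
from tinygrad import Tensor

print(Tensor([0, 1, 3]).one_hot(4).numpy())  # [[1 0 0 0] [0 1 0 0] [0 0 0 1]]

try: Tensor([0.5, 1.0]).one_hot(4)           # float dtype now raises at call time
except RuntimeError as e: print(e)           # "expect integer dtype, ..."
```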