mirror of
https://github.com/tinygrad/tinygrad.git
synced 2026-04-29 03:00:14 -04:00
One hot (#972)
* passing with 1d indices * passing all test * cleanup * using safe_numpy for scalar
This commit is contained in:
@@ -251,4 +251,12 @@ def NegativeLogLikelihoodLoss(input, target, weight=None, ignore_index=None, red
|
||||
loss = (-mask * input).sum(axis=1) * (1 if weight is None else weight)
|
||||
if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum()
|
||||
elif reduction == "sum": return loss.sum()
|
||||
return loss.reshape(t_shape) if len(i_shape) != 3 else loss
|
||||
return loss.reshape(t_shape) if len(i_shape) != 3 else loss
|
||||
|
||||
def OneHot(indices, depth, values, axis=-1):
|
||||
depth = int(safe_numpy(depth).item())
|
||||
indices, rank = (indices.cast(dtypes.float32) < 0).where(indices+depth, indices), len(indices.shape)
|
||||
if axis < 0: axis += rank + 1
|
||||
ls, rs = indices.shape[0:axis], indices.shape[axis: rank]
|
||||
cond = indices[:,None] == Tensor.arange(depth).reshape((1,) * len(ls) + (depth,) + (1,) * len(rs))
|
||||
return cond.where(values[1], values[0]).cast(values.dtype)
|
||||
|
||||
Reference in New Issue
Block a user