* passing with 1d indices

* passing all test

* cleanup

* using safe_numpy for scalar
This commit is contained in:
Steven Anderson
2023-06-12 10:13:29 -07:00
committed by GitHub
parent 613c74ca9f
commit e54b6c5e7f

View File

@@ -251,4 +251,12 @@ def NegativeLogLikelihoodLoss(input, target, weight=None, ignore_index=None, red
loss = (-mask * input).sum(axis=1) * (1 if weight is None else weight)
if reduction == "mean": return loss.mean() if weight is None else loss.sum() / weight.sum()
elif reduction == "sum": return loss.sum()
return loss.reshape(t_shape) if len(i_shape) != 3 else loss
return loss.reshape(t_shape) if len(i_shape) != 3 else loss
def OneHot(indices, depth, values, axis=-1):
depth = int(safe_numpy(depth).item())
indices, rank = (indices.cast(dtypes.float32) < 0).where(indices+depth, indices), len(indices.shape)
if axis < 0: axis += rank + 1
ls, rs = indices.shape[0:axis], indices.shape[axis: rank]
cond = indices[:,None] == Tensor.arange(depth).reshape((1,) * len(ls) + (depth,) + (1,) * len(rs))
return cond.where(values[1], values[0]).cast(values.dtype)